diff --git a/src/amd/vulkan/radv_shader_args.c b/src/amd/vulkan/radv_shader_args.c index 62aabfc0f1e..9890f2cdac6 100644 --- a/src/amd/vulkan/radv_shader_args.c +++ b/src/amd/vulkan/radv_shader_args.c @@ -118,11 +118,11 @@ count_vs_user_sgprs(struct radv_shader_args *args) } static unsigned -count_ngg_sgprs(struct radv_shader_args *args, gl_shader_stage stage) +count_ngg_sgprs(struct radv_shader_args *args, bool has_api_gs) { unsigned count = 0; - if (stage == MESA_SHADER_GEOMETRY) + if (has_api_gs) count += 1; /* ngg_gs_state */ if (args->shader_info->has_ngg_culling) count += 5; /* ngg_culling_settings + 4x ngg_viewport_* */ @@ -174,7 +174,7 @@ allocate_inline_push_consts(struct radv_shader_args *args, struct user_sgpr_info static void allocate_user_sgprs(struct radv_shader_args *args, gl_shader_stage stage, bool has_previous_stage, - gl_shader_stage previous_stage, bool needs_view_index, + gl_shader_stage previous_stage, bool needs_view_index, bool has_api_gs, struct user_sgpr_info *user_sgpr_info) { uint8_t user_sgpr_count = 0; @@ -199,8 +199,6 @@ allocate_user_sgprs(struct radv_shader_args *args, gl_shader_stage stage, bool h case MESA_SHADER_VERTEX: if (!args->is_gs_copy_shader) user_sgpr_count += count_vs_user_sgprs(args); - if (args->shader_info->is_ngg) - user_sgpr_count += count_ngg_sgprs(args, stage); break; case MESA_SHADER_TESS_CTRL: if (has_previous_stage) { @@ -209,13 +207,11 @@ allocate_user_sgprs(struct radv_shader_args *args, gl_shader_stage stage, bool h } break; case MESA_SHADER_TESS_EVAL: - if (args->shader_info->is_ngg) - user_sgpr_count += count_ngg_sgprs(args, stage); break; case MESA_SHADER_GEOMETRY: if (has_previous_stage) { if (args->shader_info->is_ngg) - user_sgpr_count += count_ngg_sgprs(args, stage); + user_sgpr_count += count_ngg_sgprs(args, has_api_gs); if (previous_stage == MESA_SHADER_VERTEX) { user_sgpr_count += count_vs_user_sgprs(args); @@ -370,9 +366,9 @@ declare_tes_input_vgprs(struct radv_shader_args *args) } static void -declare_ngg_sgprs(struct radv_shader_args *args, gl_shader_stage stage) +declare_ngg_sgprs(struct radv_shader_args *args, bool has_api_gs) { - if (stage == MESA_SHADER_GEOMETRY) { + if (has_api_gs) { ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->ngg_gs_state); } @@ -435,10 +431,9 @@ set_vs_specific_input_locs(struct radv_shader_args *args, gl_shader_stage stage, } static void -set_ngg_sgprs_locs(struct radv_shader_args *args, gl_shader_stage stage, uint8_t *user_sgpr_idx) +set_ngg_sgprs_locs(struct radv_shader_args *args, uint8_t *user_sgpr_idx) { - if (stage == MESA_SHADER_GEOMETRY) { - assert(args->ngg_gs_state.used); + if (args->ngg_gs_state.used) { set_loc_shader(args, AC_UD_NGG_GS_STATE, user_sgpr_idx, 1); } @@ -465,6 +460,7 @@ radv_declare_shader_args(struct radv_shader_args *args, gl_shader_stage stage, { struct user_sgpr_info user_sgpr_info; bool needs_view_index = needs_view_index_sgpr(args, stage); + bool has_api_gs = stage == MESA_SHADER_GEOMETRY; if (args->options->chip_class >= GFX10) { if (is_pre_gs_stage(stage) && args->shader_info->is_ngg) { @@ -481,7 +477,7 @@ radv_declare_shader_args(struct radv_shader_args *args, gl_shader_stage stage, args->shader_info->user_sgprs_locs.shader_data[i].sgpr_idx = -1; allocate_user_sgprs(args, stage, has_previous_stage, previous_stage, needs_view_index, - &user_sgpr_info); + has_api_gs, &user_sgpr_info); if (args->options->explicit_scratch_args) { ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_DESC_PTR, &args->ring_offsets); @@ -520,6 +516,9 @@ radv_declare_shader_args(struct radv_shader_args *args, gl_shader_stage stage, ac_add_arg(&args->ac, AC_ARG_VGPR, 3, AC_ARG_INT, &args->ac.local_invocation_ids); break; case MESA_SHADER_VERTEX: + /* NGG is handled by the GS case */ + assert(!args->shader_info->is_ngg); + declare_global_input_sgprs(args, &user_sgpr_info); declare_vs_specific_input_sgprs(args, stage, has_previous_stage, previous_stage); @@ -539,9 +538,6 @@ radv_declare_shader_args(struct radv_shader_args *args, gl_shader_stage stage, if (args->options->explicit_scratch_args) { ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->ac.scratch_offset); } - if (args->shader_info->is_ngg) { - declare_ngg_sgprs(args, stage); - } declare_vs_input_vgprs(args); break; @@ -585,6 +581,9 @@ radv_declare_shader_args(struct radv_shader_args *args, gl_shader_stage stage, } break; case MESA_SHADER_TESS_EVAL: + /* NGG is handled by the GS case */ + assert(!args->shader_info->is_ngg); + declare_global_input_sgprs(args, &user_sgpr_info); if (needs_view_index) @@ -601,9 +600,6 @@ radv_declare_shader_args(struct radv_shader_args *args, gl_shader_stage stage, if (args->options->explicit_scratch_args) { ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->ac.scratch_offset); } - if (args->shader_info->is_ngg) { - declare_ngg_sgprs(args, stage); - } declare_tes_input_vgprs(args); break; case MESA_SHADER_GEOMETRY: @@ -633,7 +629,7 @@ radv_declare_shader_args(struct radv_shader_args *args, gl_shader_stage stage, } if (args->shader_info->is_ngg) { - declare_ngg_sgprs(args, stage); + declare_ngg_sgprs(args, has_api_gs); } ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.gs_vtx_offset[0]); @@ -729,8 +725,6 @@ radv_declare_shader_args(struct radv_shader_args *args, gl_shader_stage stage, set_vs_specific_input_locs(args, stage, has_previous_stage, previous_stage, &user_sgpr_idx); if (args->ac.view_index.used) set_loc_shader(args, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1); - if (args->shader_info->is_ngg) - set_ngg_sgprs_locs(args, stage, &user_sgpr_idx); break; case MESA_SHADER_TESS_CTRL: set_vs_specific_input_locs(args, stage, has_previous_stage, previous_stage, &user_sgpr_idx); @@ -740,8 +734,6 @@ radv_declare_shader_args(struct radv_shader_args *args, gl_shader_stage stage, case MESA_SHADER_TESS_EVAL: if (args->ac.view_index.used) set_loc_shader(args, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1); - if (args->shader_info->is_ngg) - set_ngg_sgprs_locs(args, stage, &user_sgpr_idx); break; case MESA_SHADER_GEOMETRY: if (has_previous_stage) { @@ -753,7 +745,7 @@ radv_declare_shader_args(struct radv_shader_args *args, gl_shader_stage stage, set_loc_shader(args, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1); if (args->shader_info->is_ngg) - set_ngg_sgprs_locs(args, stage, &user_sgpr_idx); + set_ngg_sgprs_locs(args, &user_sgpr_idx); break; case MESA_SHADER_FRAGMENT: break;