ac/nir: extract a load_subgroup_id lowered helper

this will be used in the next commit

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32782>
This commit is contained in:
Marek Olšák
2024-12-26 04:05:46 -05:00
committed by Marge Bot
parent 85ce311a36
commit 433ca6ba38
5 changed files with 45 additions and 32 deletions

View File

@@ -99,6 +99,8 @@ typedef struct {
const struct ac_shader_args *const args;
const enum amd_gfx_level gfx_level;
bool has_ls_vgpr_init_bug;
unsigned wave_size;
unsigned workgroup_size;
const enum ac_hw_stage hw_stage;
nir_def *vertex_id;
@@ -123,6 +125,38 @@ preload_arg(lower_intrinsics_to_args_state *s, nir_function_impl *impl, struct a
return value;
}
/* Build the lowered replacement value for load_subgroup_id.
 *
 * s: lowering state; this reads workgroup_size, wave_size, hw_stage,
 *    gfx_level and the shader argument descriptors in s->args.
 * b: builder used to emit the replacement value.
 *
 * Returns the replacement def, or NULL for compute shaders on GFX12 and
 * newer, where this helper builds no replacement.
 * NOTE(review): the caller must treat NULL as "leave the intrinsic
 * untouched" -- confirm that every call site checks for it.
 */
static nir_def *
load_subgroup_id_lowered(lower_intrinsics_to_args_state *s, nir_builder *b)
{
   /* A workgroup no larger than one wave gets the constant subgroup ID 0. */
   if (s->workgroup_size <= s->wave_size) {
      return nir_imm_int(b, 0);
   } else if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
      /* No lowering is done here for GFX12+. This function returns a
       * pointer, so report "no replacement" with NULL rather than false.
       */
      if (s->gfx_level >= GFX12)
         return NULL;

      assert(s->args->tg_size.used);

      if (s->gfx_level >= GFX10_3) {
         return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 20, 5);
      } else {
         /* GFX6-10 don't actually support a wave id, but we can
          * use the ordered id because ORDERED_APPEND_* is set to
          * zero in the compute dispatch initiator.
          */
         return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 6, 6);
      }
   } else if (s->hw_stage == AC_HW_HULL_SHADER && s->gfx_level >= GFX11) {
      assert(s->args->tcs_wave_id.used);
      return ac_nir_unpack_arg(b, s->args, s->args->tcs_wave_id, 0, 3);
   } else if (s->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
              s->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
      assert(s->args->merged_wave_info.used);
      return ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 24, 4);
   } else {
      /* Every other hardware stage gets the constant 0. */
      return nir_imm_int(b, 0);
   }
}
static bool
lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
{
@@ -135,35 +169,9 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
b->cursor = nir_after_instr(&intrin->instr);
switch (intrin->intrinsic) {
case nir_intrinsic_load_subgroup_id: {
if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
if (s->gfx_level >= GFX12)
return false;
assert(s->args->tg_size.used);
if (s->gfx_level >= GFX10_3) {
replacement = ac_nir_unpack_arg(b, s->args, s->args->tg_size, 20, 5);
} else {
/* GFX6-10 don't actually support a wave id, but we can
* use the ordered id because ORDERED_APPEND_* is set to
* zero in the compute dispatch initiator.
*/
replacement = ac_nir_unpack_arg(b, s->args, s->args->tg_size, 6, 6);
}
} else if (s->hw_stage == AC_HW_HULL_SHADER && s->gfx_level >= GFX11) {
assert(s->args->tcs_wave_id.used);
replacement = ac_nir_unpack_arg(b, s->args, s->args->tcs_wave_id, 0, 3);
} else if (s->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
s->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
assert(s->args->merged_wave_info.used);
replacement = ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 24, 4);
} else {
replacement = nir_imm_int(b, 0);
}
case nir_intrinsic_load_subgroup_id:
replacement = load_subgroup_id_lowered(s, b);
break;
}
case nir_intrinsic_load_num_subgroups: {
if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
assert(s->args->tg_size.used);
@@ -381,12 +389,15 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
bool
ac_nir_lower_intrinsics_to_args(nir_shader *shader, const enum amd_gfx_level gfx_level,
bool has_ls_vgpr_init_bug, const enum ac_hw_stage hw_stage,
unsigned wave_size, unsigned workgroup_size,
const struct ac_shader_args *ac_args)
{
lower_intrinsics_to_args_state state = {
.gfx_level = gfx_level,
.hw_stage = hw_stage,
.has_ls_vgpr_init_bug = has_ls_vgpr_init_bug,
.wave_size = wave_size,
.workgroup_size = workgroup_size,
.args = ac_args,
};

View File

@@ -78,6 +78,7 @@ bool ac_nir_lower_sin_cos(nir_shader *shader);
bool ac_nir_lower_intrinsics_to_args(nir_shader *shader, const enum amd_gfx_level gfx_level,
bool has_ls_vgpr_init_bug, const enum ac_hw_stage hw_stage,
unsigned wave_size, unsigned workgroup_size,
const struct ac_shader_args *ac_args);
bool ac_nir_optimize_outputs(nir_shader *nir, bool sprite_tex_disallowed,

View File

@@ -517,7 +517,8 @@ radv_postprocess_nir(struct radv_device *device, const struct radv_graphics_stat
NIR_PASS(_, stage->nir, ac_nir_lower_global_access);
NIR_PASS_V(stage->nir, ac_nir_lower_intrinsics_to_args, gfx_level,
pdev->info.has_ls_vgpr_init_bug && gfx_state && !gfx_state->vs.has_prolog,
radv_select_hw_stage(&stage->info, gfx_level), &stage->args.ac);
radv_select_hw_stage(&stage->info, gfx_level), stage->info.wave_size, stage->info.workgroup_size,
&stage->args.ac);
NIR_PASS_V(stage->nir, radv_nir_lower_abi, gfx_level, stage, gfx_state, pdev->info.address32_hi);
radv_optimize_nir_algebraic(
stage->nir, io_to_mem || lowered_ngg || stage->stage == MESA_SHADER_COMPUTE || stage->stage == MESA_SHADER_TASK,

View File

@@ -2278,7 +2278,7 @@ radv_create_gs_copy_shader(struct radv_device *device, struct vk_pipeline_cache
gs_copy_stage.info.inline_push_constant_mask = gs_copy_stage.args.ac.inline_push_const_mask;
NIR_PASS_V(nir, ac_nir_lower_intrinsics_to_args, pdev->info.gfx_level, pdev->info.has_ls_vgpr_init_bug,
AC_HW_VERTEX_SHADER, &gs_copy_stage.args.ac);
AC_HW_VERTEX_SHADER, 64, 64, &gs_copy_stage.args.ac);
NIR_PASS_V(nir, radv_nir_lower_abi, pdev->info.gfx_level, &gs_copy_stage, gfx_state, pdev->info.address32_hi);
struct radv_graphics_pipeline_key key = {0};

View File

@@ -2596,7 +2596,7 @@ static struct nir_shader *si_get_nir_shader(struct si_shader *shader, struct si_
NIR_PASS(progress, nir, ac_nir_lower_intrinsics_to_args, sel->screen->info.gfx_level,
sel->screen->info.has_ls_vgpr_init_bug,
si_select_hw_stage(nir->info.stage, key, sel->screen->info.gfx_level),
&args->ac);
shader->wave_size, si_get_max_workgroup_size(shader), &args->ac);
if (progress) {
si_nir_opts(sel->screen, nir, false);
@@ -2766,7 +2766,7 @@ si_nir_generate_gs_copy_shader(struct si_screen *sscreen,
NIR_PASS_V(nir, si_nir_lower_abi, shader, &args);
NIR_PASS_V(nir, ac_nir_lower_intrinsics_to_args, sscreen->info.gfx_level,
sscreen->info.has_ls_vgpr_init_bug, AC_HW_VERTEX_SHADER, &args.ac);
sscreen->info.has_ls_vgpr_init_bug, AC_HW_VERTEX_SHADER, 64, 64, &args.ac);
si_nir_opts(gs_selector->screen, nir, false);