ac/nir: fix lowering subgroup ID for compute shaders on GFX12

This is lowered in backend compilers (LLVM or ACO) because it needs
to access ttmp registers which aren't exposed to NIR.

Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32940>
This commit is contained in:
Samuel Pitoiset
2025-01-03 01:54:00 -08:00
committed by Marge Bot
parent bc1374355b
commit 44ba856089

View File

@@ -159,10 +159,7 @@ load_subgroup_id_lowered(lower_intrinsics_to_args_state *s, nir_builder *b)
if (s->workgroup_size <= s->wave_size) {
return nir_imm_int(b, 0);
} else if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
if (s->gfx_level >= GFX12)
return false;
assert(s->args->tg_size.used);
assert(s->gfx_level < GFX12 && s->args->tg_size.used);
if (s->gfx_level >= GFX10_3) {
return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 20, 5);
@@ -198,6 +195,8 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
switch (intrin->intrinsic) {
case nir_intrinsic_load_subgroup_id:
if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER)
return false; /* Lowered in backend compilers. */
replacement = load_subgroup_id_lowered(s, b);
break;
case nir_intrinsic_load_num_subgroups: {
@@ -556,8 +555,16 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
nir_def *wave_id_mul_64 = nir_iand_imm(b, ac_nir_load_arg(b, s->args, s->args->tg_size), 0xfc0);
replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), wave_id_mul_64);
} else {
nir_def *subgroup_id;
if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER) {
subgroup_id = nir_load_subgroup_id(b);
} else {
subgroup_id = load_subgroup_id_lowered(s, b);
}
replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size),
nir_imul_imm(b, load_subgroup_id_lowered(s, b), s->wave_size));
nir_imul_imm(b, subgroup_id, s->wave_size));
}
break;
case nir_intrinsic_load_subgroup_invocation: