ac/nir: set upper ranges for range analysis while lowering system values

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32782>
This commit is contained in:
Marek Olšák
2024-12-29 21:12:40 -05:00
committed by Marge Bot
parent 0d5b03f2b9
commit 5dd9171765

View File

@@ -75,11 +75,9 @@ ac_nir_store_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct ac
nir_store_vector_arg_amd(b, val, .base = arg.arg_index); nir_store_vector_arg_amd(b, val, .base = arg.arg_index);
} }
nir_def * static nir_def *
ac_nir_unpack_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct ac_arg arg, ac_nir_unpack_value(nir_builder *b, nir_def *value, unsigned rshift, unsigned bitwidth)
unsigned rshift, unsigned bitwidth)
{ {
nir_def *value = ac_nir_load_arg(b, ac_args, arg);
if (rshift == 0 && bitwidth == 32) if (rshift == 0 && bitwidth == 32)
return value; return value;
else if (rshift == 0) else if (rshift == 0)
@@ -90,6 +88,14 @@ ac_nir_unpack_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct a
return nir_ubfe_imm(b, value, rshift, bitwidth); return nir_ubfe_imm(b, value, rshift, bitwidth);
} }
nir_def *
ac_nir_unpack_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct ac_arg arg,
unsigned rshift, unsigned bitwidth)
{
nir_def *value = ac_nir_load_arg(b, ac_args, arg);
return ac_nir_unpack_value(b, value, rshift, bitwidth);
}
static bool static bool
is_sin_cos(const nir_instr *instr, UNUSED const void *_) is_sin_cos(const nir_instr *instr, UNUSED const void *_)
{ {
@@ -232,50 +238,64 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
ac_nir_load_arg(b, s->args, s->args->frag_pos[2]), ac_nir_load_arg(b, s->args, s->args->frag_pos[2]),
ac_nir_load_arg(b, s->args, s->args->frag_pos[3])); ac_nir_load_arg(b, s->args, s->args->frag_pos[3]));
break; break;
case nir_intrinsic_load_local_invocation_id: case nir_intrinsic_load_local_invocation_id: {
unsigned num_bits[3];
nir_def *vec[3];
for (unsigned i = 0; i < 3; i++) {
bool has_chan = b->shader->info.workgroup_size_variable ||
b->shader->info.workgroup_size[i] > 1;
/* Extract as few bits possible - we want the constant to be an inline constant
* instead of a literal.
*/
num_bits[i] = !has_chan ? 0 :
b->shader->info.workgroup_size_variable ?
10 : util_logbase2_ceil(b->shader->info.workgroup_size[i]);
}
if (s->args->local_invocation_ids_packed.used) { if (s->args->local_invocation_ids_packed.used) {
/* Thread IDs are packed in VGPR0, 10 bits per component. */ unsigned extract_bits[3];
unsigned num_bits[3]; memcpy(extract_bits, num_bits, sizeof(num_bits));
for (unsigned i = 0; i < 3; i++) { /* Thread IDs are packed in VGPR0, 10 bits per component.
bool has_chan = b->shader->info.workgroup_size_variable || * Always extract all remaining bits if later ID components are always 0, which will
b->shader->info.workgroup_size[i] > 1;
/* Extract as few bits possible - we want the constant to be an inline constant
* instead of a literal. ID.z should always extract all remaining bits, which
* will translate to a bit shift.
*/
num_bits[i] = !has_chan ? 0 :
i == 2 ? 12 :
b->shader->info.workgroup_size_variable ?
10 : util_logbase2_ceil(b->shader->info.workgroup_size[i]);
}
/* Always extract all remaining bits if later ID components are always 0, which will
* translate to a bit shift. * translate to a bit shift.
*/ */
if (!num_bits[2]) { if (num_bits[2]) {
if (num_bits[1]) extract_bits[2] = 12; /* Z > 0 */
num_bits[1] = 22; /* Y > 0, Z == 0 */ } else if (num_bits[1])
else if (num_bits[0]) extract_bits[1] = 22; /* Y > 0, Z == 0 */
num_bits[0] = 32; /* X > 0, Y == 0, Z == 0 */ else if (num_bits[0])
} extract_bits[0] = 32; /* X > 0, Y == 0, Z == 0 */
nir_def *ids_packed =
ac_nir_load_arg_upper_bound(b, s->args, s->args->local_invocation_ids_packed,
b->shader->info.workgroup_size_variable ?
0 : ((b->shader->info.workgroup_size[0] - 1) |
((b->shader->info.workgroup_size[1] - 1) << 10) |
((b->shader->info.workgroup_size[2] - 1) << 20)));
nir_def *vec[3];
for (unsigned i = 0; i < 3; i++) { for (unsigned i = 0; i < 3; i++) {
vec[i] = !num_bits[i] ? nir_imm_int(b, 0) : vec[i] = !num_bits[i] ? nir_imm_int(b, 0) :
ac_nir_unpack_arg(b, s->args, ac_nir_unpack_value(b, ids_packed, i * 10, extract_bits[i]);
s->args->local_invocation_ids_packed, i * 10,
num_bits[i]);
} }
replacement = nir_vec(b, vec, 3);
} else { } else {
replacement = nir_vec3(b, const struct ac_arg ids[] = {
ac_nir_load_arg(b, s->args, s->args->local_invocation_id_x), s->args->local_invocation_id_x,
ac_nir_load_arg(b, s->args, s->args->local_invocation_id_y), s->args->local_invocation_id_y,
ac_nir_load_arg(b, s->args, s->args->local_invocation_id_z)); s->args->local_invocation_id_z,
};
for (unsigned i = 0; i < 3; i++) {
unsigned max = b->shader->info.workgroup_size_variable ?
1023 : (b->shader->info.workgroup_size[i] - 1);
vec[i] = !num_bits[i] ? nir_imm_int(b, 0) :
ac_nir_load_arg_upper_bound(b, s->args, ids[i], max);
}
} }
replacement = nir_vec(b, vec, 3);
break; break;
}
case nir_intrinsic_load_merged_wave_info_amd: case nir_intrinsic_load_merged_wave_info_amd:
replacement = ac_nir_load_arg(b, s->args, s->args->merged_wave_info); replacement = ac_nir_load_arg(b, s->args, s->args->merged_wave_info);
break; break;
@@ -332,7 +352,7 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
replacement = ac_nir_load_arg(b, s->args, s->args->draw_id); replacement = ac_nir_load_arg(b, s->args, s->args->draw_id);
break; break;
case nir_intrinsic_load_view_index: case nir_intrinsic_load_view_index:
replacement = ac_nir_load_arg(b, s->args, s->args->view_index); replacement = ac_nir_load_arg_upper_bound(b, s->args, s->args->view_index, 1);
break; break;
case nir_intrinsic_load_invocation_id: case nir_intrinsic_load_invocation_id:
if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) { if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
@@ -341,9 +361,9 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
if (s->gfx_level >= GFX12) { if (s->gfx_level >= GFX12) {
replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_vtx_offset[0], 27, 5); replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_vtx_offset[0], 27, 5);
} else if (s->gfx_level >= GFX10) { } else if (s->gfx_level >= GFX10) {
replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_invocation_id, 0, 7); replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_invocation_id, 0, 5);
} else { } else {
replacement = ac_nir_load_arg(b, s->args, s->args->gs_invocation_id); replacement = ac_nir_load_arg_upper_bound(b, s->args, s->args->gs_invocation_id, 31);
} }
} else { } else {
unreachable("unexpected shader stage"); unreachable("unexpected shader stage");