ac/nir: set upper ranges for range analysis while lowering system values

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32782>
This commit is contained in:
Marek Olšák
2024-12-29 21:12:40 -05:00
committed by Marge Bot
parent 0d5b03f2b9
commit 5dd9171765

View File

@@ -75,11 +75,9 @@ ac_nir_store_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct ac
nir_store_vector_arg_amd(b, val, .base = arg.arg_index);
}
nir_def *
ac_nir_unpack_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct ac_arg arg,
unsigned rshift, unsigned bitwidth)
static nir_def *
ac_nir_unpack_value(nir_builder *b, nir_def *value, unsigned rshift, unsigned bitwidth)
{
nir_def *value = ac_nir_load_arg(b, ac_args, arg);
if (rshift == 0 && bitwidth == 32)
return value;
else if (rshift == 0)
@@ -90,6 +88,14 @@ ac_nir_unpack_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct a
return nir_ubfe_imm(b, value, rshift, bitwidth);
}
nir_def *
ac_nir_unpack_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct ac_arg arg,
unsigned rshift, unsigned bitwidth)
{
nir_def *value = ac_nir_load_arg(b, ac_args, arg);
return ac_nir_unpack_value(b, value, rshift, bitwidth);
}
static bool
is_sin_cos(const nir_instr *instr, UNUSED const void *_)
{
@@ -232,50 +238,64 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
ac_nir_load_arg(b, s->args, s->args->frag_pos[2]),
ac_nir_load_arg(b, s->args, s->args->frag_pos[3]));
break;
case nir_intrinsic_load_local_invocation_id:
case nir_intrinsic_load_local_invocation_id: {
unsigned num_bits[3];
nir_def *vec[3];
for (unsigned i = 0; i < 3; i++) {
bool has_chan = b->shader->info.workgroup_size_variable ||
b->shader->info.workgroup_size[i] > 1;
/* Extract as few bits possible - we want the constant to be an inline constant
* instead of a literal.
*/
num_bits[i] = !has_chan ? 0 :
b->shader->info.workgroup_size_variable ?
10 : util_logbase2_ceil(b->shader->info.workgroup_size[i]);
}
if (s->args->local_invocation_ids_packed.used) {
/* Thread IDs are packed in VGPR0, 10 bits per component. */
unsigned num_bits[3];
unsigned extract_bits[3];
memcpy(extract_bits, num_bits, sizeof(num_bits));
for (unsigned i = 0; i < 3; i++) {
bool has_chan = b->shader->info.workgroup_size_variable ||
b->shader->info.workgroup_size[i] > 1;
/* Extract as few bits possible - we want the constant to be an inline constant
* instead of a literal. ID.z should always extract all remaining bits, which
* will translate to a bit shift.
*/
num_bits[i] = !has_chan ? 0 :
i == 2 ? 12 :
b->shader->info.workgroup_size_variable ?
10 : util_logbase2_ceil(b->shader->info.workgroup_size[i]);
}
/* Always extract all remaining bits if later ID components are always 0, which will
/* Thread IDs are packed in VGPR0, 10 bits per component.
* Always extract all remaining bits if later ID components are always 0, which will
* translate to a bit shift.
*/
if (!num_bits[2]) {
if (num_bits[1])
num_bits[1] = 22; /* Y > 0, Z == 0 */
else if (num_bits[0])
num_bits[0] = 32; /* X > 0, Y == 0, Z == 0 */
}
if (num_bits[2]) {
extract_bits[2] = 12; /* Z > 0 */
} else if (num_bits[1])
extract_bits[1] = 22; /* Y > 0, Z == 0 */
else if (num_bits[0])
extract_bits[0] = 32; /* X > 0, Y == 0, Z == 0 */
nir_def *ids_packed =
ac_nir_load_arg_upper_bound(b, s->args, s->args->local_invocation_ids_packed,
b->shader->info.workgroup_size_variable ?
0 : ((b->shader->info.workgroup_size[0] - 1) |
((b->shader->info.workgroup_size[1] - 1) << 10) |
((b->shader->info.workgroup_size[2] - 1) << 20)));
nir_def *vec[3];
for (unsigned i = 0; i < 3; i++) {
vec[i] = !num_bits[i] ? nir_imm_int(b, 0) :
ac_nir_unpack_arg(b, s->args,
s->args->local_invocation_ids_packed, i * 10,
num_bits[i]);
ac_nir_unpack_value(b, ids_packed, i * 10, extract_bits[i]);
}
replacement = nir_vec(b, vec, 3);
} else {
replacement = nir_vec3(b,
ac_nir_load_arg(b, s->args, s->args->local_invocation_id_x),
ac_nir_load_arg(b, s->args, s->args->local_invocation_id_y),
ac_nir_load_arg(b, s->args, s->args->local_invocation_id_z));
const struct ac_arg ids[] = {
s->args->local_invocation_id_x,
s->args->local_invocation_id_y,
s->args->local_invocation_id_z,
};
for (unsigned i = 0; i < 3; i++) {
unsigned max = b->shader->info.workgroup_size_variable ?
1023 : (b->shader->info.workgroup_size[i] - 1);
vec[i] = !num_bits[i] ? nir_imm_int(b, 0) :
ac_nir_load_arg_upper_bound(b, s->args, ids[i], max);
}
}
replacement = nir_vec(b, vec, 3);
break;
}
case nir_intrinsic_load_merged_wave_info_amd:
replacement = ac_nir_load_arg(b, s->args, s->args->merged_wave_info);
break;
@@ -332,7 +352,7 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
replacement = ac_nir_load_arg(b, s->args, s->args->draw_id);
break;
case nir_intrinsic_load_view_index:
replacement = ac_nir_load_arg(b, s->args, s->args->view_index);
replacement = ac_nir_load_arg_upper_bound(b, s->args, s->args->view_index, 1);
break;
case nir_intrinsic_load_invocation_id:
if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
@@ -341,9 +361,9 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
if (s->gfx_level >= GFX12) {
replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_vtx_offset[0], 27, 5);
} else if (s->gfx_level >= GFX10) {
replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_invocation_id, 0, 7);
replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_invocation_id, 0, 5);
} else {
replacement = ac_nir_load_arg(b, s->args, s->args->gs_invocation_id);
replacement = ac_nir_load_arg_upper_bound(b, s->args, s->args->gs_invocation_id, 31);
}
} else {
unreachable("unexpected shader stage");