nir: add cheap shortcut for wg id to wg idx lowering

... for platforms where integer division is expensive

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22334>
Authored by Marcin Ślusarz on 2023-03-31 12:19:43 +02:00; committed by Marge Bot
parent 7ec1ef75d3
commit 3d7513ee8e
2 changed files with 50 additions and 11 deletions

src/compiler/nir/nir.h

@@ -5408,6 +5408,10 @@ typedef struct nir_lower_compute_system_values_options {
    bool lower_local_invocation_index:1;
    bool lower_cs_local_id_to_index:1;
    bool lower_workgroup_id_to_index:1;
+   /* At shader execution time, check if WorkGroupId should be 1D
+    * and compute it quickly. Fall back to slow computation if not.
+    */
+   bool shortcut_1d_workgroup_id:1;
    uint16_t num_workgroups[3]; /* Compile-time-known dispatch sizes, or 0 if unknown. */
 } nir_lower_compute_system_values_options;
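For orientation, here is a minimal sketch of how a driver that already lowers WorkGroupId to an index might opt into the new flag. This is an illustration, not code from any driver in the tree; the helper name, include path and exact option values are assumptions.

#include "nir.h"

/* Hypothetical driver helper: lower compute system values and ask for the
 * runtime 1D shortcut on hardware where integer division is expensive. */
static bool
lower_compute_sysvals_with_shortcut(nir_shader *shader)
{
   const nir_lower_compute_system_values_options opts = {
      .lower_workgroup_id_to_index = true,
      .shortcut_1d_workgroup_id = true,   /* new flag added by this commit */
      .num_workgroups = { 0, 0, 0 },      /* dispatch size unknown at compile time */
   };

   bool progress = false;
   NIR_PASS(progress, shader, nir_lower_compute_system_values, &opts);
   return progress;
}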

src/compiler/nir/nir_lower_system_values.c

@@ -359,9 +359,9 @@ nir_lower_system_values(nir_shader *shader)
 }
 
 static nir_ssa_def *
-lower_id_to_index_no_umod(nir_builder *b, nir_ssa_def *index,
-                          nir_ssa_def *size, unsigned bit_size,
-                          const uint16_t *size_imm)
+id_to_index_no_umod_slow(nir_builder *b, nir_ssa_def *index,
+                         nir_ssa_def *size_x, nir_ssa_def *size_y,
+                         unsigned bit_size)
 {
    /* We lower ID to Index with the following formula:
     *
@@ -374,6 +374,22 @@ lower_id_to_index_no_umod(nir_builder *b, nir_ssa_def *index,
     * not compile time known or not a power of two.
     */
 
+   nir_ssa_def *size_x_y = nir_imul(b, size_x, size_y);
+   nir_ssa_def *id_z = nir_udiv(b, index, size_x_y);
+   nir_ssa_def *z_portion = nir_imul(b, id_z, size_x_y);
+   nir_ssa_def *id_y = nir_udiv(b, nir_isub(b, index, z_portion), size_x);
+   nir_ssa_def *y_portion = nir_imul(b, id_y, size_x);
+   nir_ssa_def *id_x = nir_isub(b, index, nir_iadd(b, z_portion, y_portion));
+
+   return nir_u2uN(b, nir_vec3(b, id_x, id_y, id_z), bit_size);
+}
+
+static nir_ssa_def *
+lower_id_to_index_no_umod(nir_builder *b, nir_ssa_def *index,
+                          nir_ssa_def *size, unsigned bit_size,
+                          const uint16_t *size_imm,
+                          bool shortcut_1d)
+{
    nir_ssa_def *size_x, *size_y;
 
    if (size_imm[0] > 0)
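The decomposition built by id_to_index_no_umod_slow can be sanity-checked with ordinary integer arithmetic. The standalone C model below only illustrates the emitted nir_udiv/nir_imul/nir_isub chain; the function name, variable names and the example dispatch size are made up for this sketch, and workgroup sizes are assumed to be at least 1.

#include <assert.h>
#include <stdint.h>

/* Mirror of the lowering: recover a 3D workgroup ID from a linear index
 * using two divisions and no modulo. */
static void id_from_index(uint32_t index, uint32_t size_x, uint32_t size_y,
                          uint32_t id[3])
{
   uint32_t size_x_y = size_x * size_y;              /* nir_imul */
   uint32_t id_z = index / size_x_y;                 /* nir_udiv */
   uint32_t z_portion = id_z * size_x_y;             /* nir_imul */
   uint32_t id_y = (index - z_portion) / size_x;     /* nir_isub + nir_udiv */
   uint32_t y_portion = id_y * size_x;               /* nir_imul */
   uint32_t id_x = index - (z_portion + y_portion);  /* nir_iadd + nir_isub */

   id[0] = id_x;
   id[1] = id_y;
   id[2] = id_z;
}

int main(void)
{
   /* index 23 in a 4x3xN dispatch: 23 = 1*(4*3) + 2*4 + 3 -> id = (3, 2, 1) */
   uint32_t id[3];
   id_from_index(23, 4, 3, id);
   assert(id[0] == 3 && id[1] == 2 && id[2] == 1);
   return 0;
}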
@@ -386,15 +402,33 @@ lower_id_to_index_no_umod(nir_builder *b, nir_ssa_def *index,
    else
       size_y = nir_channel(b, size, 1);
 
-   nir_ssa_def *size_x_y = nir_imul(b, size_x, size_y);
+   if (shortcut_1d) {
+      /* if size.y + size.z == 2 (which means that both y and z are 1)
+       *    id = vec3(index, 0, 0)
+       * else
+       *    id = id_to_index_no_umod_slow
+       */
-   nir_ssa_def *id_z = nir_udiv(b, index, size_x_y);
-   nir_ssa_def *z_portion = nir_imul(b, id_z, size_x_y);
-   nir_ssa_def *id_y = nir_udiv(b, nir_isub(b, index, z_portion), size_x);
-   nir_ssa_def *y_portion = nir_imul(b, id_y, size_x);
-   nir_ssa_def *id_x = nir_isub(b, index, nir_iadd(b, z_portion, y_portion));
+      nir_ssa_def *size_z = nir_channel(b, size, 2);
+      nir_ssa_def *cond = nir_ieq(b, nir_iadd(b, size_y, size_z), nir_imm_int(b, 2));
-
-   return nir_u2uN(b, nir_vec3(b, id_x, id_y, id_z), bit_size);
+      nir_ssa_def *val1, *val2;
+      nir_if *if_opt = nir_push_if(b, cond);
+      if_opt->control = nir_selection_control_dont_flatten;
+      {
+         nir_ssa_def *zero = nir_imm_int(b, 0);
+         val1 = nir_u2uN(b, nir_vec3(b, index, zero, zero), bit_size);
+      }
+      nir_push_else(b, if_opt);
+      {
+         val2 = id_to_index_no_umod_slow(b, index, size_x, size_y, bit_size);
+      }
+      nir_pop_if(b, if_opt);
+
+      return nir_if_phi(b, val1, val2);
+   } else {
+      return id_to_index_no_umod_slow(b, index, size_x, size_y, bit_size);
+   }
 }
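Put differently, the control flow emitted for the shortcut behaves like the C model below. This is an illustrative sketch, not code from the tree; workgroup_id_from_index and the size array are invented names, and workgroup counts are assumed to be at least 1.

#include <stdint.h>

/* Model of the emitted control flow: when the dispatch is effectively 1D
 * (size.y == 1 and size.z == 1, i.e. size.y + size.z == 2), the ID is just
 * (index, 0, 0) and both integer divisions are skipped; otherwise the slow
 * decomposition runs. */
static void workgroup_id_from_index(uint32_t index, const uint32_t size[3],
                                    uint32_t id[3])
{
   if (size[1] + size[2] == 2) {
      id[0] = index;
      id[1] = 0;
      id[2] = 0;
   } else {
      /* Same decomposition as id_from_index() in the sketch above. */
      uint32_t size_x_y = size[0] * size[1];
      uint32_t id_z = index / size_x_y;
      uint32_t id_y = (index - id_z * size_x_y) / size[0];
      id[0] = index - (id_z * size_x_y + id_y * size[0]);
      id[1] = id_y;
      id[2] = id_z;
   }
}

The nir_selection_control_dont_flatten hint on the nir_if asks later passes not to flatten the branch into selects, which would evaluate both sides and reintroduce the divisions on the 1D path.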
@@ -694,7 +728,8 @@ lower_compute_system_value_instr(nir_builder *b,
          return lower_id_to_index_no_umod(b, wg_idx,
                                           nir_load_num_workgroups(b, bit_size),
                                           bit_size,
-                                          options->num_workgroups);
+                                          options->num_workgroups,
+                                          options->shortcut_1d_workgroup_id);
       }
 
       return NULL;