intel/brw: Reorganize lowering of LocalID/Index to handle Mesh/Task

Reorganize the code to make clearer all the lowering cases:

(a) Single invocation workgroup.  Index and IDs are all zero.
(b) Local ID provided by hardware.
(c) Local Index provided by the hardware.  Depending on the case this
    might not be the final local index, e.g. heuristics for tile.
(d) Neither provided by the hardware.

Case (c) is new and supported by Mesh/Task shaders.  At the moment the
nir_lower_compute_system_values handle lowering of LocalID for
Task/Mesh, but a later patch will flip that on ANV.

This will make the Task/Mesh use the same lowering as Compute shaders.

Reviewed-by: Kenneth Graunke <kenneth@whitecape.org>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29828>
This commit is contained in:
Caio Oliveira
2024-06-20 19:49:27 -07:00
committed by Marge Bot
parent f0b0a71a9b
commit d89bfb1ff7

View File

@@ -27,23 +27,80 @@
struct lower_intrinsics_state {
nir_shader *nir;
nir_function_impl *impl;
enum gl_derivative_group derivative_group;
bool progress;
bool hw_generated_local_id;
nir_builder builder;
/* Per-block cached values. */
bool computed;
nir_def *hw_index;
nir_def *local_index;
nir_def *local_id;
};
static void
compute_local_index_id(nir_builder *b,
nir_shader *nir,
nir_def **local_index,
nir_def **local_id)
compute_local_index_id(struct lower_intrinsics_state *state, nir_intrinsic_instr *current)
{
nir_def *subgroup_id = nir_load_subgroup_id(b);
assert(!state->computed);
state->hw_index = NULL;
state->local_index = NULL;
state->local_id = NULL;
state->computed = true;
nir_shader *nir = state->nir;
nir_builder *b = &state->builder;
if (!nir->info.workgroup_size_variable) {
/* Don't calculate anything for a single invocation workgroup. */
const uint16_t *ws = nir->info.workgroup_size;
if (ws[0] * ws[1] * ws[2] == 1) {
nir_def *zero = nir_imm_int(b, 0);
state->local_index = zero;
state->local_id = nir_replicate(b, zero, 3);
return;
}
if (state->hw_generated_local_id) {
assert(state->derivative_group != DERIVATIVE_GROUP_QUADS);
nir_def *local_id_vec = nir_load_local_invocation_id(b);
nir_def *local_id[3] = { nir_channel(b, local_id_vec, 0),
nir_channel(b, local_id_vec, 1),
nir_channel(b, local_id_vec, 2) };
nir_def *size_x = nir_imm_int(b, nir->info.workgroup_size[0]);
nir_def *size_y = nir_imm_int(b, nir->info.workgroup_size[1]);
nir_def *local_index = nir_imul(b, local_id[2], nir_imul(b, size_x, size_y));
local_index = nir_iadd(b, local_index, nir_imul(b, local_id[1], size_x));
local_index = nir_iadd(b, local_index, local_id[0]);
state->local_index = local_index;
state->local_id = NULL;
return;
}
}
/* Linear index. Depending on the heuristic or the derivative group, will
* need to be processed to become the actual local_index.
*/
nir_def *linear;
if (nir->info.stage == MESA_SHADER_MESH || nir->info.stage == MESA_SHADER_TASK) {
/* Thread payload provides a linear index, keep track of it
* so it doesn't get removed.
*/
state->hw_index =
current->intrinsic == nir_intrinsic_load_local_invocation_index ?
&current->def : nir_load_local_invocation_index(b);
linear = state->hw_index;
} else {
nir_def *subgroup_id = nir_load_subgroup_id(b);
nir_def *thread_local_id =
nir_imul(b, subgroup_id, nir_load_simd_width_intel(b));
nir_def *channel = nir_load_subgroup_invocation(b);
nir_def *linear = nir_iadd(b, channel, thread_local_id);
linear = nir_iadd(b, channel, thread_local_id);
}
nir_def *size_x;
nir_def *size_y;
@@ -75,7 +132,7 @@ compute_local_index_id(nir_builder *b,
*/
nir_def *id_x, *id_y, *id_z;
switch (nir->info.cs.derivative_group) {
switch (state->derivative_group) {
case DERIVATIVE_GROUP_NONE:
if (nir->info.num_images == 0 &&
nir->info.num_textures == 0) {
@@ -85,7 +142,7 @@ compute_local_index_id(nir_builder *b,
*/
id_x = nir_umod(b, linear, size_x);
id_y = nir_umod(b, nir_udiv(b, linear, size_x), size_y);
*local_index = linear;
state->local_index = linear;
} else if (!nir->info.workgroup_size_variable &&
nir->info.workgroup_size[1] % 4 == 0) {
/* 1x4 block X-major lid order. Same as X-major except increments in
@@ -116,9 +173,9 @@ compute_local_index_id(nir_builder *b,
}
id_z = nir_udiv(b, linear, size_xy);
*local_id = nir_vec3(b, id_x, id_y, id_z);
if (!*local_index) {
*local_index = nir_iadd(b, nir_iadd(b, id_x,
state->local_id = nir_vec3(b, id_x, id_y, id_z);
if (!state->local_index) {
state->local_index = nir_iadd(b, nir_iadd(b, id_x,
nir_imul(b, id_y, size_x)),
nir_imul(b, id_z, size_xy));
}
@@ -130,8 +187,8 @@ compute_local_index_id(nir_builder *b,
id_x = nir_umod(b, linear, size_x);
id_y = nir_umod(b, nir_udiv(b, linear, size_x), size_y);
id_z = nir_udiv(b, linear, size_xy);
*local_id = nir_vec3(b, id_x, id_y, id_z);
*local_index = linear;
state->local_id = nir_vec3(b, id_x, id_y, id_z);
state->local_index = linear;
break;
case DERIVATIVE_GROUP_QUADS: {
/* For quads, first we figure out the 2x2 grid the invocation
@@ -157,10 +214,10 @@ compute_local_index_id(nir_builder *b,
nir_ishl(b, y_row_pairs, one),
nir_iand(b, nir_ishr(b, row_pair_id, one), one));
*local_id = nir_vec3(b, x,
state->local_id = nir_vec3(b, x,
nir_umod(b, y, size_y),
nir_udiv(b, y, size_y));
*local_index = nir_iadd(b, x, nir_imul(b, y, size_x));
state->local_index = nir_iadd(b, x, nir_imul(b, y, size_x));
break;
}
default:
@@ -176,9 +233,8 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state,
nir_builder *b = &state->builder;
nir_shader *nir = state->nir;
/* Reuse calculated values inside the block. */
nir_def *local_index = NULL;
nir_def *local_id = NULL;
/* Reset per-block definitions. */
state->computed = false;
nir_foreach_instr_safe(instr, block) {
if (instr->type != nir_instr_type_intrinsic)
@@ -190,56 +246,30 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state,
nir_def *sysval;
switch (intrinsic->intrinsic) {
case nir_intrinsic_load_local_invocation_id:
if (state->hw_generated_local_id)
continue;
case nir_intrinsic_load_local_invocation_id: {
if (!state->computed)
compute_local_index_id(state, intrinsic);
FALLTHROUGH;
case nir_intrinsic_load_local_invocation_index: {
if (!local_index && !nir->info.workgroup_size_variable) {
const uint16_t *ws = nir->info.workgroup_size;
if (ws[0] * ws[1] * ws[2] == 1) {
nir_def *zero = nir_imm_int(b, 0);
local_index = zero;
local_id = nir_replicate(b, zero, 3);
}
}
if (!local_index) {
if (nir->info.stage == MESA_SHADER_TASK ||
nir->info.stage == MESA_SHADER_MESH) {
/* Will be lowered by nir_emit_task_mesh_intrinsic() using
* information from the payload.
*/
if (!state->local_id) {
/* Will be lowered later by the backend. */
assert(state->hw_generated_local_id);
continue;
}
if (state->hw_generated_local_id) {
nir_def *local_id_vec = nir_load_local_invocation_id(b);
nir_def *local_id[3] = { nir_channel(b, local_id_vec, 0),
nir_channel(b, local_id_vec, 1),
nir_channel(b, local_id_vec, 2) };
nir_def *size_x = nir_imm_int(b, nir->info.workgroup_size[0]);
nir_def *size_y = nir_imm_int(b, nir->info.workgroup_size[1]);
sysval = nir_imul(b, local_id[2], nir_imul(b, size_x, size_y));
sysval = nir_iadd(b, sysval, nir_imul(b, local_id[1], size_x));
sysval = nir_iadd(b, sysval, local_id[0]);
local_index = sysval;
sysval = state->local_id;
break;
}
/* First time we are using those, so let's calculate them. */
assert(!local_id);
compute_local_index_id(b, nir, &local_index, &local_id);
}
case nir_intrinsic_load_local_invocation_index: {
if (!state->computed)
compute_local_index_id(state, intrinsic);
assert(local_id);
assert(local_index);
if (intrinsic->intrinsic == nir_intrinsic_load_local_invocation_id)
sysval = local_id;
else
sysval = local_index;
/* Will be lowered later by the backend. */
if (&intrinsic->def == state->hw_index)
continue;
assert(state->local_index);
sysval = state->local_index;
break;
}
@@ -303,15 +333,16 @@ brw_nir_lower_cs_intrinsics(nir_shader *nir,
struct lower_intrinsics_state state = {
.nir = nir,
.hw_generated_local_id = false,
.derivative_group = gl_shader_stage_is_compute(nir->info.stage) ?
nir->info.cs.derivative_group : DERIVATIVE_GROUP_NONE,
};
/* Constraints from NV_compute_shader_derivatives. */
if (gl_shader_stage_is_compute(nir->info.stage) &&
!nir->info.workgroup_size_variable) {
if (nir->info.cs.derivative_group == DERIVATIVE_GROUP_QUADS) {
if (!nir->info.workgroup_size_variable) {
if (state.derivative_group == DERIVATIVE_GROUP_QUADS) {
assert(nir->info.workgroup_size[0] % 2 == 0);
assert(nir->info.workgroup_size[1] % 2 == 0);
} else if (nir->info.cs.derivative_group == DERIVATIVE_GROUP_LINEAR) {
} else if (state.derivative_group == DERIVATIVE_GROUP_LINEAR) {
ASSERTED unsigned workgroup_size =
nir->info.workgroup_size[0] *
nir->info.workgroup_size[1] *
@@ -322,7 +353,7 @@ brw_nir_lower_cs_intrinsics(nir_shader *nir,
if (devinfo->verx10 >= 125 && prog_data &&
nir->info.stage == MESA_SHADER_COMPUTE &&
nir->info.cs.derivative_group != DERIVATIVE_GROUP_QUADS &&
state.derivative_group != DERIVATIVE_GROUP_QUADS &&
!nir->info.workgroup_size_variable &&
util_is_power_of_two_nonzero(nir->info.workgroup_size[0]) &&
util_is_power_of_two_nonzero(nir->info.workgroup_size[1])) {