amd: lower load_local_invocation_id in NIR
This is based on ACO.

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32782>
This commit is contained in:
@@ -180,6 +180,46 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
|
||||
ac_nir_load_arg(b, s->args, s->args->frag_pos[2]),
|
||||
ac_nir_load_arg(b, s->args, s->args->frag_pos[3]));
|
||||
break;
|
||||
case nir_intrinsic_load_local_invocation_id:
|
||||
if (s->args->args[s->args->local_invocation_ids.arg_index].size == 1) {
|
||||
/* Thread IDs are packed in VGPR0, 10 bits per component. */
|
||||
unsigned num_bits[3];
|
||||
|
||||
for (unsigned i = 0; i < 3; i++) {
|
||||
bool has_chan = b->shader->info.workgroup_size_variable ||
|
||||
b->shader->info.workgroup_size[i] > 1;
|
||||
/* Extract as few bits as possible - we want the constant to be an inline constant
|
||||
* instead of a literal. ID.z should always extract all remaining bits, which
|
||||
* will translate to a bit shift.
|
||||
*/
|
||||
num_bits[i] = !has_chan ? 0 :
|
||||
i == 2 ? 12 :
|
||||
b->shader->info.workgroup_size_variable ?
|
||||
10 : util_logbase2_ceil(b->shader->info.workgroup_size[i]);
|
||||
}
|
||||
|
||||
/* Always extract all remaining bits if later ID components are always 0, which will
|
||||
* translate to a bit shift.
|
||||
*/
|
||||
if (!num_bits[2]) {
|
||||
if (num_bits[1])
|
||||
num_bits[1] = 22; /* Y > 0, Z == 0 */
|
||||
else if (num_bits[0])
|
||||
num_bits[0] = 32; /* X > 0, Y == 0, Z == 0 */
|
||||
}
|
||||
|
||||
nir_def *vec[3];
|
||||
for (unsigned i = 0; i < 3; i++) {
|
||||
vec[i] = !num_bits[i] ? nir_imm_int(b, 0) :
|
||||
ac_nir_unpack_arg(b, s->args, s->args->local_invocation_ids, i * 10,
|
||||
num_bits[i]);
|
||||
}
|
||||
|
||||
replacement = nir_vec(b, vec, 3);
|
||||
} else {
|
||||
replacement = ac_nir_load_arg(b, s->args, s->args->local_invocation_ids);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
@@ -8200,43 +8200,6 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
|
||||
emit_split_vector(ctx, dst, 3);
|
||||
break;
|
||||
}
|
||||
case nir_intrinsic_load_local_invocation_id: {
|
||||
Temp dst = get_ssa_temp(ctx, &instr->def);
|
||||
if (ctx->options->gfx_level >= GFX11) {
|
||||
Temp local_ids[3];
|
||||
|
||||
/* Thread IDs are packed in VGPR0, 10 bits per component. */
|
||||
local_ids[0] = get_arg(ctx, ctx->args->local_invocation_ids);
|
||||
if (ctx->shader->info.workgroup_size[1] > 1 || ctx->shader->info.workgroup_size[2] > 1 ||
|
||||
ctx->shader->info.workgroup_size_variable) {
|
||||
unsigned size_x = ctx->shader->info.workgroup_size_variable
|
||||
? 1024
|
||||
: util_next_power_of_two(ctx->shader->info.workgroup_size[0]);
|
||||
Temp mask = bld.copy(bld.def(s1), Operand::c32(size_x - 1));
|
||||
local_ids[0] = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), mask, local_ids[0]);
|
||||
}
|
||||
|
||||
for (uint32_t i = 1; i < 3; i++) {
|
||||
if (i == 2 || (i == 1 && ctx->shader->info.workgroup_size[2] == 1 &&
|
||||
!ctx->shader->info.workgroup_size_variable)) {
|
||||
local_ids[i] =
|
||||
bld.vop2(aco_opcode::v_lshrrev_b32, bld.def(v1), Operand::c32(i * 10u),
|
||||
get_arg(ctx, ctx->args->local_invocation_ids));
|
||||
} else {
|
||||
local_ids[i] = bld.vop3(aco_opcode::v_bfe_u32, bld.def(v1),
|
||||
get_arg(ctx, ctx->args->local_invocation_ids),
|
||||
Operand::c32(i * 10u), Operand::c32(10u));
|
||||
}
|
||||
}
|
||||
|
||||
bld.pseudo(aco_opcode::p_create_vector, Definition(dst), local_ids[0], local_ids[1],
|
||||
local_ids[2]);
|
||||
} else {
|
||||
bld.copy(Definition(dst), Operand(get_arg(ctx, ctx->args->local_invocation_ids)));
|
||||
}
|
||||
emit_split_vector(ctx, dst, 3);
|
||||
break;
|
||||
}
|
||||
case nir_intrinsic_load_workgroup_id: {
|
||||
Temp dst = get_ssa_temp(ctx, &instr->def);
|
||||
if (ctx->stage.hw == AC_HW_COMPUTE_SHADER) {
|
||||
|
@@ -562,7 +562,6 @@ init_context(isel_context* ctx, nir_shader* shader)
|
||||
case nir_intrinsic_load_front_face_fsign:
|
||||
case nir_intrinsic_load_frag_shading_rate:
|
||||
case nir_intrinsic_load_sample_pos:
|
||||
case nir_intrinsic_load_local_invocation_id:
|
||||
case nir_intrinsic_load_local_invocation_index:
|
||||
case nir_intrinsic_load_subgroup_invocation:
|
||||
case nir_intrinsic_load_tess_coord:
|
||||
|
@@ -2955,22 +2955,6 @@ static bool visit_intrinsic(struct ac_nir_context *ctx, nir_intrinsic_instr *ins
|
||||
case nir_intrinsic_load_vertex_id_zero_base:
|
||||
result = ctx->abi->vertex_id_replaced ? ctx->abi->vertex_id_replaced : ctx->abi->vertex_id;
|
||||
break;
|
||||
case nir_intrinsic_load_local_invocation_id: {
|
||||
LLVMValueRef ids = ac_get_arg(&ctx->ac, ctx->args->local_invocation_ids);
|
||||
|
||||
if (LLVMGetTypeKind(LLVMTypeOf(ids)) == LLVMIntegerTypeKind) {
|
||||
/* Thread IDs are packed in VGPR0, 10 bits per component. */
|
||||
LLVMValueRef id[3];
|
||||
|
||||
for (unsigned i = 0; i < 3; i++)
|
||||
id[i] = ac_unpack_param(&ctx->ac, ids, i * 10, 10);
|
||||
|
||||
result = ac_build_gather_values(&ctx->ac, id, 3);
|
||||
} else {
|
||||
result = ids;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case nir_intrinsic_load_base_instance:
|
||||
result = ac_get_arg(&ctx->ac, ctx->args->start_instance);
|
||||
break;
|
||||
|
Reference in New Issue
Block a user