aco/gfx12: use ttmp9/ttmp7 for workgroup id

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29330>
This commit is contained in:
Rhys Perry
2024-05-20 16:57:25 +01:00
committed by Marge Bot
parent c8123b67e0
commit ef74407577
2 changed files with 60 additions and 14 deletions

View File

@@ -8341,11 +8341,8 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
case nir_intrinsic_load_workgroup_id: {
Temp dst = get_ssa_temp(ctx, &instr->def);
if (ctx->stage.hw == AC_HW_COMPUTE_SHADER) {
const struct ac_arg* ids = ctx->args->workgroup_ids;
bld.pseudo(aco_opcode::p_create_vector, Definition(dst),
ids[0].used ? Operand(get_arg(ctx, ids[0])) : Operand::zero(),
ids[1].used ? Operand(get_arg(ctx, ids[1])) : Operand::zero(),
ids[2].used ? Operand(get_arg(ctx, ids[2])) : Operand::zero());
bld.pseudo(aco_opcode::p_create_vector, Definition(dst), ctx->workgroup_id[0],
ctx->workgroup_id[1], ctx->workgroup_id[2]);
emit_split_vector(ctx, dst, 3);
} else {
isel_err(&instr->instr, "Unsupported stage for load_workgroup_id");
@@ -11148,6 +11145,9 @@ add_startpgm(struct isel_context* ctx)
def_count++;
}
if (ctx->stage.hw == AC_HW_COMPUTE_SHADER && ctx->program->gfx_level >= GFX12)
def_count += 2;
Instruction* startpgm = create_instruction(aco_opcode::p_startpgm, Format::PSEUDO, 0, def_count);
ctx->block->instructions.emplace_back(startpgm);
for (unsigned i = 0, arg = 0; i < ctx->args->arg_count; i++) {
@@ -11180,6 +11180,32 @@ add_startpgm(struct isel_context* ctx)
}
}
if (ctx->program->gfx_level >= GFX12 && ctx->stage.hw == AC_HW_COMPUTE_SHADER) {
Temp idx = ctx->program->allocateTmp(s1);
Temp idy = ctx->program->allocateTmp(s1);
startpgm->definitions[def_count - 2] = Definition(idx);
startpgm->definitions[def_count - 2].setFixed(PhysReg(108 + 9 /*ttmp9*/));
startpgm->definitions[def_count - 1] = Definition(idy);
startpgm->definitions[def_count - 1].setFixed(PhysReg(108 + 7 /*ttmp7*/));
ctx->workgroup_id[0] = Operand(idx);
if (ctx->args->workgroup_ids[2].used) {
Builder bld(ctx->program, ctx->block);
ctx->workgroup_id[1] =
bld.pseudo(aco_opcode::p_extract, bld.def(s1), bld.def(s1, scc), idy, Operand::zero(),
Operand::c32(16u), Operand::zero());
ctx->workgroup_id[2] =
bld.pseudo(aco_opcode::p_extract, bld.def(s1), bld.def(s1, scc), idy, Operand::c32(1u),
Operand::c32(16u), Operand::zero());
} else {
ctx->workgroup_id[1] = Operand(idy);
ctx->workgroup_id[2] = Operand::zero();
}
} else if (ctx->stage.hw == AC_HW_COMPUTE_SHADER) {
const struct ac_arg* ids = ctx->args->workgroup_ids;
for (unsigned i = 0; i < 3; i++)
ctx->workgroup_id[i] = ids[i].used ? Operand(get_arg(ctx, ids[i])) : Operand::zero();
}
/* epilog has no scratch */
if (ctx->args->scratch_offset.used) {
if (ctx->program->gfx_level < GFX9) {
@@ -12289,10 +12315,18 @@ select_rt_prolog(Program* program, ac_shader_config* config,
PhysReg in_sbt_desc = get_arg_reg(in_args, in_args->rt.sbt_descriptors);
PhysReg in_launch_size_addr = get_arg_reg(in_args, in_args->rt.launch_size_addr);
PhysReg in_stack_base = get_arg_reg(in_args, in_args->rt.dynamic_callable_stack_base);
PhysReg in_wg_id_x = get_arg_reg(in_args, in_args->workgroup_ids[0]);
PhysReg in_wg_id_y = get_arg_reg(in_args, in_args->workgroup_ids[1]);
PhysReg in_wg_id_z = get_arg_reg(in_args, in_args->workgroup_ids[2]);
PhysReg in_wg_id_x;
PhysReg in_wg_id_y;
PhysReg in_wg_id_z;
PhysReg in_scratch_offset;
if (options->gfx_level < GFX12) {
in_wg_id_x = get_arg_reg(in_args, in_args->workgroup_ids[0]);
in_wg_id_y = get_arg_reg(in_args, in_args->workgroup_ids[1]);
in_wg_id_z = get_arg_reg(in_args, in_args->workgroup_ids[2]);
} else {
in_wg_id_x = PhysReg(108 + 9 /*ttmp9*/);
in_wg_id_y = PhysReg(108 + 7 /*ttmp7*/);
}
if (options->gfx_level < GFX11)
in_scratch_offset = get_arg_reg(in_args, in_args->scratch_offset);
PhysReg in_local_ids[2] = {
@@ -12330,6 +12364,8 @@ select_rt_prolog(Program* program, ac_shader_config* config,
num_sgprs += 2;
PhysReg tmp_ring_offsets = PhysReg{num_sgprs};
num_sgprs += 2;
PhysReg tmp_wg_id_x_times_size = PhysReg{num_sgprs};
num_sgprs++;
PhysReg tmp_invocation_idx = PhysReg{256 + num_vgprs++};
@@ -12379,9 +12415,18 @@ select_rt_prolog(Program* program, ac_shader_config* config,
Operand(in_local_ids[0], v1));
}
/* Do this backwards to reduce some RAW hazards on GFX11+ */
bld.vop1(aco_opcode::v_mov_b32, Definition(out_launch_ids[2], v1), Operand(in_wg_id_z, s1));
bld.vop3(aco_opcode::v_mad_u32_u24, Definition(out_launch_ids[1], v1), Operand(in_wg_id_y, s1),
Operand::c32(program->workgroup_size == 32 ? 4 : 8), Operand(in_local_ids[1], v1));
if (options->gfx_level >= GFX12) {
bld.vop2_e64(aco_opcode::v_lshrrev_b32, Definition(out_launch_ids[2], v1), Operand::c32(16),
Operand(in_wg_id_y, s1));
bld.vop3(aco_opcode::v_mad_u32_u16, Definition(out_launch_ids[1], v1),
Operand(in_wg_id_y, s1), Operand::c32(program->workgroup_size == 32 ? 4 : 8),
Operand(in_local_ids[1], v1));
} else {
bld.vop1(aco_opcode::v_mov_b32, Definition(out_launch_ids[2], v1), Operand(in_wg_id_z, s1));
bld.vop3(aco_opcode::v_mad_u32_u24, Definition(out_launch_ids[1], v1),
Operand(in_wg_id_y, s1), Operand::c32(program->workgroup_size == 32 ? 4 : 8),
Operand(in_local_ids[1], v1));
}
bld.vop3(aco_opcode::v_mad_u32_u24, Definition(out_launch_ids[0], v1), Operand(in_wg_id_x, s1),
Operand::c32(8), Operand(in_local_ids[0], v1));
@@ -12407,14 +12452,14 @@ select_rt_prolog(Program* program, ac_shader_config* config,
/* For 1D dispatches converted into 2D ones, we need to fix up the launch IDs.
* Calculating the 1D launch ID is: id = local_invocation_index + (wg_id.x * wg_size).
* in_wg_id_x now holds wg_id.x * wg_size.
* tmp_wg_id_x_times_size now holds wg_id.x * wg_size.
*/
bld.sop2(aco_opcode::s_lshl_b32, Definition(in_wg_id_x, s1), Definition(scc, s1),
bld.sop2(aco_opcode::s_lshl_b32, Definition(tmp_wg_id_x_times_size, s1), Definition(scc, s1),
Operand(in_wg_id_x, s1), Operand::c32(program->workgroup_size == 32 ? 5 : 6));
/* Calculate and add local_invocation_index */
bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, Definition(tmp_invocation_idx, v1), Operand::c32(-1u),
Operand(in_wg_id_x, s1));
Operand(tmp_wg_id_x_times_size, s1));
if (program->wave_size == 64) {
if (program->gfx_level <= GFX7)
bld.vop2(aco_opcode::v_mbcnt_hi_u32_b32, Definition(tmp_invocation_idx, v1),

View File

@@ -72,6 +72,7 @@ struct isel_context {
nir_unsigned_upper_bound_config ub_config;
Temp arg_temps[AC_MAX_ARGS];
Operand workgroup_id[3];
/* tessellation information */
uint64_t tcs_temp_only_inputs;