aco: fix wmma raw hazard

No fossil-db changes.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29912>
This commit is contained in:
Rhys Perry
2024-06-25 14:39:00 +01:00
committed by Marge Bot
parent a6eb5c9caa
commit 17758f0a02
2 changed files with 126 additions and 3 deletions

View File

@@ -266,6 +266,9 @@ struct NOP_ctx_gfx11 {
std::bitset<128> sgpr_read_by_valu_as_lanemask;
std::bitset<128> sgpr_read_by_valu_as_lanemask_then_wr_by_salu;
/* WMMAHazards */
std::bitset<256> vgpr_written_by_wmma;
void join(const NOP_ctx_gfx11& other)
{
has_Vcmpx |= other.has_Vcmpx;
@@ -279,6 +282,7 @@ struct NOP_ctx_gfx11 {
sgpr_read_by_valu_as_lanemask |= other.sgpr_read_by_valu_as_lanemask;
sgpr_read_by_valu_as_lanemask_then_wr_by_salu |=
other.sgpr_read_by_valu_as_lanemask_then_wr_by_salu;
vgpr_written_by_wmma |= other.vgpr_written_by_wmma;
}
bool operator==(const NOP_ctx_gfx11& other)
@@ -293,7 +297,8 @@ struct NOP_ctx_gfx11 {
trans_since_wr_by_trans == other.trans_since_wr_by_trans &&
sgpr_read_by_valu_as_lanemask == other.sgpr_read_by_valu_as_lanemask &&
sgpr_read_by_valu_as_lanemask_then_wr_by_salu ==
other.sgpr_read_by_valu_as_lanemask_then_wr_by_salu;
other.sgpr_read_by_valu_as_lanemask_then_wr_by_salu &&
vgpr_written_by_wmma == other.vgpr_written_by_wmma;
}
};
@@ -1118,6 +1123,18 @@ fill_vgpr_bitset(std::bitset<256>& set, PhysReg reg, unsigned bytes)
set.set(reg.reg() - 256 + i);
}
bool
test_vgpr_bitset(std::bitset<256>& set, Operand op)
{
if (op.physReg().reg() < 256)
return false;
for (unsigned i = 0; i < op.size(); i++) {
if (set[op.physReg().reg() - 256 + i])
return true;
}
return false;
}
/* GFX11 */
unsigned
parse_vdst_wait(aco_ptr<Instruction>& instr)
@@ -1568,6 +1585,24 @@ handle_instruction_gfx11(State& state, NOP_ctx_gfx11& ctx, aco_ptr<Instruction>&
ctx.vgpr_used_by_ds.reset();
}
}
/* WMMA Hazards */
if (instr_info.classes[(int)instr->opcode] == instr_class::wmma) {
assert(instr->operands.back().regClass() == instr->definitions[0].regClass());
bool is_swmma = instr->operands.size() == 4;
if (test_vgpr_bitset(ctx.vgpr_written_by_wmma, instr->operands[0]) ||
test_vgpr_bitset(ctx.vgpr_written_by_wmma, instr->operands[1]) ||
(is_swmma && test_vgpr_bitset(ctx.vgpr_written_by_wmma, instr->operands[2]))) {
bld.vop1(aco_opcode::v_nop);
}
ctx.vgpr_written_by_wmma.reset();
fill_vgpr_bitset(ctx.vgpr_written_by_wmma, instr->definitions[0].physReg(),
instr->definitions[0].bytes());
} else if (instr->isVALU()) {
ctx.vgpr_written_by_wmma.reset();
}
}
bool
@@ -1619,9 +1654,10 @@ resolve_all_gfx11(State& state, NOP_ctx_gfx11& ctx,
ctx.trans_since_wr_by_trans.reset();
}
/* VcmpxPermlaneHazard */
if (ctx.has_Vcmpx) {
/* VcmpxPermlaneHazard/WMMAHazards */
if (ctx.has_Vcmpx || ctx.vgpr_written_by_wmma.any()) {
ctx.has_Vcmpx = false;
ctx.vgpr_written_by_wmma.reset();
bld.vop1(aco_opcode::v_nop);
}

View File

@@ -1129,6 +1129,79 @@ BEGIN_TEST(insert_nops.valu_mask_write)
finish_insert_nops_test();
END_TEST
BEGIN_TEST(insert_nops.wmma_raw)
if (!setup_cs(NULL, GFX11))
return;
/* Basic case. */
//>> p_unit_test 0
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[0-7].xx, %_:v[8-15].xx, %_:v[20-23].xx
//! v_nop
//! v4: %_:v[48-51] = v_wmma_f16_16x16x16_f16 %_:v[24-31].xx, %_:v[16-23].xx, %_:v[48-51].xx
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(0));
Operand A(PhysReg(256 + 0), v8);
Operand B(PhysReg(256 + 8), v8);
Operand C(PhysReg(256 + 20), v4);
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
0);
A.setFixed(PhysReg(256 + 24));
B.setFixed(PhysReg(256 + 16));
C.setFixed(PhysReg(256 + 48));
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
0);
/* Mitigation. */
//! p_unit_test 1
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[0-7].xx, %_:v[8-15].xx, %_:v[20-23].xx
//! v1: %_:v[56] = v_rcp_f32 0
//! v4: %_:v[48-51] = v_wmma_f16_16x16x16_f16 %_:v[24-31].xx, %_:v[16-23].xx, %_:v[48-51].xx
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(1));
A.setFixed(PhysReg(256 + 0));
B.setFixed(PhysReg(256 + 8));
C.setFixed(PhysReg(256 + 20));
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
0);
bld.vop1(aco_opcode::v_rcp_f32, Definition(PhysReg(256 + 56), v1), Operand::zero());
A.setFixed(PhysReg(256 + 24));
B.setFixed(PhysReg(256 + 16));
C.setFixed(PhysReg(256 + 48));
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
0);
/* No hazard. */
//>> p_unit_test 2
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[0-7].xx, %_:v[8-15].xx, %_:v[20-23].xx
//! v4: %_:v[48-51] = v_wmma_f16_16x16x16_f16 %_:v[24-31].xx, %_:v[32-39].xx, %_:v[48-51].xx
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(2));
A.setFixed(PhysReg(256 + 0));
B.setFixed(PhysReg(256 + 8));
C.setFixed(PhysReg(256 + 20));
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
0);
A.setFixed(PhysReg(256 + 24));
B.setFixed(PhysReg(256 + 32));
C.setFixed(PhysReg(256 + 48));
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
0);
//>> p_unit_test 3
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[0-7].xx, %_:v[8-15].xx, %_:v[20-23].xx
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[24-31].xx, %_:v[32-39].xx, %_:v[20-23].xx
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(3));
A.setFixed(PhysReg(256 + 0));
B.setFixed(PhysReg(256 + 8));
C.setFixed(PhysReg(256 + 20));
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
0);
A.setFixed(PhysReg(256 + 24));
B.setFixed(PhysReg(256 + 32));
C.setFixed(PhysReg(256 + 20));
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
0);
finish_insert_nops_test();
END_TEST
BEGIN_TEST(insert_nops.setpc_gfx6)
if (!setup_cs(NULL, GFX6))
return;
@@ -1447,6 +1520,20 @@ BEGIN_TEST(insert_nops.setpc_gfx11)
bld.ds(aco_opcode::ds_read_b32, Definition(PhysReg(256), v1), Operand(PhysReg(256), v1));
bld.sop1(aco_opcode::s_setpc_b64, Operand::zero(8));
/* WMMA Hazards */
//! p_unit_test 7
//! v4: %0:v[20-23] = v_wmma_f16_16x16x16_f16 %0:v[0-7].xx, %0:v[8-15].xx, %0:v[20-23].xx
//! v_nop
//! s_waitcnt_depctr va_vdst(0)
//! s_setpc_b64 0
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(7));
Operand A(PhysReg(256 + 0), v8);
Operand B(PhysReg(256 + 8), v8);
Operand C(PhysReg(256 + 20), v4);
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
0);
bld.sop1(aco_opcode::s_setpc_b64, Operand::zero(8));
finish_insert_nops_test(true);
END_TEST