aco: fix wmma raw hazard
No fossil-db changes. Signed-off-by: Rhys Perry <pendingchaos02@gmail.com> Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29912>
This commit is contained in:
@@ -266,6 +266,9 @@ struct NOP_ctx_gfx11 {
|
||||
std::bitset<128> sgpr_read_by_valu_as_lanemask;
|
||||
std::bitset<128> sgpr_read_by_valu_as_lanemask_then_wr_by_salu;
|
||||
|
||||
/* WMMAHazards */
|
||||
std::bitset<256> vgpr_written_by_wmma;
|
||||
|
||||
void join(const NOP_ctx_gfx11& other)
|
||||
{
|
||||
has_Vcmpx |= other.has_Vcmpx;
|
||||
@@ -279,6 +282,7 @@ struct NOP_ctx_gfx11 {
|
||||
sgpr_read_by_valu_as_lanemask |= other.sgpr_read_by_valu_as_lanemask;
|
||||
sgpr_read_by_valu_as_lanemask_then_wr_by_salu |=
|
||||
other.sgpr_read_by_valu_as_lanemask_then_wr_by_salu;
|
||||
vgpr_written_by_wmma |= other.vgpr_written_by_wmma;
|
||||
}
|
||||
|
||||
bool operator==(const NOP_ctx_gfx11& other)
|
||||
@@ -293,7 +297,8 @@ struct NOP_ctx_gfx11 {
|
||||
trans_since_wr_by_trans == other.trans_since_wr_by_trans &&
|
||||
sgpr_read_by_valu_as_lanemask == other.sgpr_read_by_valu_as_lanemask &&
|
||||
sgpr_read_by_valu_as_lanemask_then_wr_by_salu ==
|
||||
other.sgpr_read_by_valu_as_lanemask_then_wr_by_salu;
|
||||
other.sgpr_read_by_valu_as_lanemask_then_wr_by_salu &&
|
||||
vgpr_written_by_wmma == other.vgpr_written_by_wmma;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1118,6 +1123,18 @@ fill_vgpr_bitset(std::bitset<256>& set, PhysReg reg, unsigned bytes)
|
||||
set.set(reg.reg() - 256 + i);
|
||||
}
|
||||
|
||||
bool
|
||||
test_vgpr_bitset(std::bitset<256>& set, Operand op)
|
||||
{
|
||||
if (op.physReg().reg() < 256)
|
||||
return false;
|
||||
for (unsigned i = 0; i < op.size(); i++) {
|
||||
if (set[op.physReg().reg() - 256 + i])
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/* GFX11 */
|
||||
unsigned
|
||||
parse_vdst_wait(aco_ptr<Instruction>& instr)
|
||||
@@ -1568,6 +1585,24 @@ handle_instruction_gfx11(State& state, NOP_ctx_gfx11& ctx, aco_ptr<Instruction>&
|
||||
ctx.vgpr_used_by_ds.reset();
|
||||
}
|
||||
}
|
||||
|
||||
/* WMMA Hazards */
|
||||
if (instr_info.classes[(int)instr->opcode] == instr_class::wmma) {
|
||||
assert(instr->operands.back().regClass() == instr->definitions[0].regClass());
|
||||
|
||||
bool is_swmma = instr->operands.size() == 4;
|
||||
if (test_vgpr_bitset(ctx.vgpr_written_by_wmma, instr->operands[0]) ||
|
||||
test_vgpr_bitset(ctx.vgpr_written_by_wmma, instr->operands[1]) ||
|
||||
(is_swmma && test_vgpr_bitset(ctx.vgpr_written_by_wmma, instr->operands[2]))) {
|
||||
bld.vop1(aco_opcode::v_nop);
|
||||
}
|
||||
|
||||
ctx.vgpr_written_by_wmma.reset();
|
||||
fill_vgpr_bitset(ctx.vgpr_written_by_wmma, instr->definitions[0].physReg(),
|
||||
instr->definitions[0].bytes());
|
||||
} else if (instr->isVALU()) {
|
||||
ctx.vgpr_written_by_wmma.reset();
|
||||
}
|
||||
}
|
||||
|
||||
bool
|
||||
@@ -1619,9 +1654,10 @@ resolve_all_gfx11(State& state, NOP_ctx_gfx11& ctx,
|
||||
ctx.trans_since_wr_by_trans.reset();
|
||||
}
|
||||
|
||||
/* VcmpxPermlaneHazard */
|
||||
if (ctx.has_Vcmpx) {
|
||||
/* VcmpxPermlaneHazard/WMMAHazards */
|
||||
if (ctx.has_Vcmpx || ctx.vgpr_written_by_wmma.any()) {
|
||||
ctx.has_Vcmpx = false;
|
||||
ctx.vgpr_written_by_wmma.reset();
|
||||
bld.vop1(aco_opcode::v_nop);
|
||||
}
|
||||
|
||||
|
@@ -1129,6 +1129,79 @@ BEGIN_TEST(insert_nops.valu_mask_write)
|
||||
finish_insert_nops_test();
|
||||
END_TEST
|
||||
|
||||
/* Exercises NOP insertion between two consecutive v_wmma_f16_16x16x16_f16
 * instructions on GFX11. Each case emits a p_unit_test marker followed by the
 * instructions under test; the specially prefixed comment lines above each
 * case describe the expected output (presumably consumed by the test harness
 * - confirm against the framework), so they must not be reworded.
 */
BEGIN_TEST(insert_nops.wmma_raw)
|
||||
if (!setup_cs(NULL, GFX11))
|
||||
return;
|
||||
|
||||
/* Basic case. */
/* The second WMMA reads v[16-23] as B, which overlaps the v[20-23] result of
 * the first WMMA, so a v_nop is expected between the two. */
|
||||
//>> p_unit_test 0
|
||||
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[0-7].xx, %_:v[8-15].xx, %_:v[20-23].xx
|
||||
//! v_nop
|
||||
//! v4: %_:v[48-51] = v_wmma_f16_16x16x16_f16 %_:v[24-31].xx, %_:v[16-23].xx, %_:v[48-51].xx
|
||||
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(0));
|
||||
Operand A(PhysReg(256 + 0), v8);
|
||||
Operand B(PhysReg(256 + 8), v8);
|
||||
Operand C(PhysReg(256 + 20), v4);
|
||||
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||
0);
|
||||
A.setFixed(PhysReg(256 + 24));
|
||||
B.setFixed(PhysReg(256 + 16));
|
||||
C.setFixed(PhysReg(256 + 48));
|
||||
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||
0);
|
||||
|
||||
/* Mitigation. */
/* Same register overlap as case 0, but an unrelated v_rcp_f32 is placed
 * between the two WMMAs; the expected output contains no v_nop. */
|
||||
//! p_unit_test 1
|
||||
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[0-7].xx, %_:v[8-15].xx, %_:v[20-23].xx
|
||||
//! v1: %_:v[56] = v_rcp_f32 0
|
||||
//! v4: %_:v[48-51] = v_wmma_f16_16x16x16_f16 %_:v[24-31].xx, %_:v[16-23].xx, %_:v[48-51].xx
|
||||
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(1));
|
||||
A.setFixed(PhysReg(256 + 0));
|
||||
B.setFixed(PhysReg(256 + 8));
|
||||
C.setFixed(PhysReg(256 + 20));
|
||||
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||
0);
|
||||
bld.vop1(aco_opcode::v_rcp_f32, Definition(PhysReg(256 + 56), v1), Operand::zero());
|
||||
A.setFixed(PhysReg(256 + 24));
|
||||
B.setFixed(PhysReg(256 + 16));
|
||||
C.setFixed(PhysReg(256 + 48));
|
||||
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||
0);
|
||||
|
||||
/* No hazard. */
/* A (v[24-31]) and B (v[32-39]) of the second WMMA do not overlap the
 * v[20-23] result of the first one; the expected output has no v_nop. */
|
||||
//>> p_unit_test 2
|
||||
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[0-7].xx, %_:v[8-15].xx, %_:v[20-23].xx
|
||||
//! v4: %_:v[48-51] = v_wmma_f16_16x16x16_f16 %_:v[24-31].xx, %_:v[32-39].xx, %_:v[48-51].xx
|
||||
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(2));
|
||||
A.setFixed(PhysReg(256 + 0));
|
||||
B.setFixed(PhysReg(256 + 8));
|
||||
C.setFixed(PhysReg(256 + 20));
|
||||
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||
0);
|
||||
A.setFixed(PhysReg(256 + 24));
|
||||
B.setFixed(PhysReg(256 + 32));
|
||||
C.setFixed(PhysReg(256 + 48));
|
||||
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||
0);
|
||||
|
||||
/* Only the C operand (and destination) of the second WMMA reuses v[20-23];
 * A and B are disjoint from it, and the expected output has no v_nop. */
//>> p_unit_test 3
|
||||
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[0-7].xx, %_:v[8-15].xx, %_:v[20-23].xx
|
||||
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[24-31].xx, %_:v[32-39].xx, %_:v[20-23].xx
|
||||
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(3));
|
||||
A.setFixed(PhysReg(256 + 0));
|
||||
B.setFixed(PhysReg(256 + 8));
|
||||
C.setFixed(PhysReg(256 + 20));
|
||||
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||
0);
|
||||
A.setFixed(PhysReg(256 + 24));
|
||||
B.setFixed(PhysReg(256 + 32));
|
||||
C.setFixed(PhysReg(256 + 20));
|
||||
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||
0);
|
||||
|
||||
finish_insert_nops_test();
|
||||
END_TEST
|
||||
|
||||
BEGIN_TEST(insert_nops.setpc_gfx6)
|
||||
if (!setup_cs(NULL, GFX6))
|
||||
return;
|
||||
@@ -1447,6 +1520,20 @@ BEGIN_TEST(insert_nops.setpc_gfx11)
|
||||
bld.ds(aco_opcode::ds_read_b32, Definition(PhysReg(256), v1), Operand(PhysReg(256), v1));
|
||||
bld.sop1(aco_opcode::s_setpc_b64, Operand::zero(8));
|
||||
|
||||
/* WMMA Hazards */
|
||||
//! p_unit_test 7
|
||||
//! v4: %0:v[20-23] = v_wmma_f16_16x16x16_f16 %0:v[0-7].xx, %0:v[8-15].xx, %0:v[20-23].xx
|
||||
//! v_nop
|
||||
//! s_waitcnt_depctr va_vdst(0)
|
||||
//! s_setpc_b64 0
|
||||
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(7));
|
||||
Operand A(PhysReg(256 + 0), v8);
|
||||
Operand B(PhysReg(256 + 8), v8);
|
||||
Operand C(PhysReg(256 + 20), v4);
|
||||
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||
0);
|
||||
bld.sop1(aco_opcode::s_setpc_b64, Operand::zero(8));
|
||||
|
||||
finish_insert_nops_test(true);
|
||||
END_TEST
|
||||
|
||||
|
Reference in New Issue
Block a user