aco: fix WMMA RAW hazard
No fossil-db changes. Signed-off-by: Rhys Perry <pendingchaos02@gmail.com> Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29912>
This commit is contained in:
@@ -266,6 +266,9 @@ struct NOP_ctx_gfx11 {
|
|||||||
std::bitset<128> sgpr_read_by_valu_as_lanemask;
|
std::bitset<128> sgpr_read_by_valu_as_lanemask;
|
||||||
std::bitset<128> sgpr_read_by_valu_as_lanemask_then_wr_by_salu;
|
std::bitset<128> sgpr_read_by_valu_as_lanemask_then_wr_by_salu;
|
||||||
|
|
||||||
|
/* WMMAHazards */
|
||||||
|
std::bitset<256> vgpr_written_by_wmma;
|
||||||
|
|
||||||
void join(const NOP_ctx_gfx11& other)
|
void join(const NOP_ctx_gfx11& other)
|
||||||
{
|
{
|
||||||
has_Vcmpx |= other.has_Vcmpx;
|
has_Vcmpx |= other.has_Vcmpx;
|
||||||
@@ -279,6 +282,7 @@ struct NOP_ctx_gfx11 {
|
|||||||
sgpr_read_by_valu_as_lanemask |= other.sgpr_read_by_valu_as_lanemask;
|
sgpr_read_by_valu_as_lanemask |= other.sgpr_read_by_valu_as_lanemask;
|
||||||
sgpr_read_by_valu_as_lanemask_then_wr_by_salu |=
|
sgpr_read_by_valu_as_lanemask_then_wr_by_salu |=
|
||||||
other.sgpr_read_by_valu_as_lanemask_then_wr_by_salu;
|
other.sgpr_read_by_valu_as_lanemask_then_wr_by_salu;
|
||||||
|
vgpr_written_by_wmma |= other.vgpr_written_by_wmma;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator==(const NOP_ctx_gfx11& other)
|
bool operator==(const NOP_ctx_gfx11& other)
|
||||||
@@ -293,7 +297,8 @@ struct NOP_ctx_gfx11 {
|
|||||||
trans_since_wr_by_trans == other.trans_since_wr_by_trans &&
|
trans_since_wr_by_trans == other.trans_since_wr_by_trans &&
|
||||||
sgpr_read_by_valu_as_lanemask == other.sgpr_read_by_valu_as_lanemask &&
|
sgpr_read_by_valu_as_lanemask == other.sgpr_read_by_valu_as_lanemask &&
|
||||||
sgpr_read_by_valu_as_lanemask_then_wr_by_salu ==
|
sgpr_read_by_valu_as_lanemask_then_wr_by_salu ==
|
||||||
other.sgpr_read_by_valu_as_lanemask_then_wr_by_salu;
|
other.sgpr_read_by_valu_as_lanemask_then_wr_by_salu &&
|
||||||
|
vgpr_written_by_wmma == other.vgpr_written_by_wmma;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1118,6 +1123,18 @@ fill_vgpr_bitset(std::bitset<256>& set, PhysReg reg, unsigned bytes)
|
|||||||
set.set(reg.reg() - 256 + i);
|
set.set(reg.reg() - 256 + i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool
|
||||||
|
test_vgpr_bitset(std::bitset<256>& set, Operand op)
|
||||||
|
{
|
||||||
|
if (op.physReg().reg() < 256)
|
||||||
|
return false;
|
||||||
|
for (unsigned i = 0; i < op.size(); i++) {
|
||||||
|
if (set[op.physReg().reg() - 256 + i])
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
/* GFX11 */
|
/* GFX11 */
|
||||||
unsigned
|
unsigned
|
||||||
parse_vdst_wait(aco_ptr<Instruction>& instr)
|
parse_vdst_wait(aco_ptr<Instruction>& instr)
|
||||||
@@ -1568,6 +1585,24 @@ handle_instruction_gfx11(State& state, NOP_ctx_gfx11& ctx, aco_ptr<Instruction>&
|
|||||||
ctx.vgpr_used_by_ds.reset();
|
ctx.vgpr_used_by_ds.reset();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* WMMA Hazards */
|
||||||
|
if (instr_info.classes[(int)instr->opcode] == instr_class::wmma) {
|
||||||
|
assert(instr->operands.back().regClass() == instr->definitions[0].regClass());
|
||||||
|
|
||||||
|
bool is_swmma = instr->operands.size() == 4;
|
||||||
|
if (test_vgpr_bitset(ctx.vgpr_written_by_wmma, instr->operands[0]) ||
|
||||||
|
test_vgpr_bitset(ctx.vgpr_written_by_wmma, instr->operands[1]) ||
|
||||||
|
(is_swmma && test_vgpr_bitset(ctx.vgpr_written_by_wmma, instr->operands[2]))) {
|
||||||
|
bld.vop1(aco_opcode::v_nop);
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.vgpr_written_by_wmma.reset();
|
||||||
|
fill_vgpr_bitset(ctx.vgpr_written_by_wmma, instr->definitions[0].physReg(),
|
||||||
|
instr->definitions[0].bytes());
|
||||||
|
} else if (instr->isVALU()) {
|
||||||
|
ctx.vgpr_written_by_wmma.reset();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool
|
bool
|
||||||
@@ -1619,9 +1654,10 @@ resolve_all_gfx11(State& state, NOP_ctx_gfx11& ctx,
|
|||||||
ctx.trans_since_wr_by_trans.reset();
|
ctx.trans_since_wr_by_trans.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* VcmpxPermlaneHazard */
|
/* VcmpxPermlaneHazard/WMMAHazards */
|
||||||
if (ctx.has_Vcmpx) {
|
if (ctx.has_Vcmpx || ctx.vgpr_written_by_wmma.any()) {
|
||||||
ctx.has_Vcmpx = false;
|
ctx.has_Vcmpx = false;
|
||||||
|
ctx.vgpr_written_by_wmma.reset();
|
||||||
bld.vop1(aco_opcode::v_nop);
|
bld.vop1(aco_opcode::v_nop);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1129,6 +1129,79 @@ BEGIN_TEST(insert_nops.valu_mask_write)
|
|||||||
finish_insert_nops_test();
|
finish_insert_nops_test();
|
||||||
END_TEST
|
END_TEST
|
||||||
|
|
||||||
|
BEGIN_TEST(insert_nops.wmma_raw)
|
||||||
|
if (!setup_cs(NULL, GFX11))
|
||||||
|
return;
|
||||||
|
|
||||||
|
/* Basic case. */
|
||||||
|
//>> p_unit_test 0
|
||||||
|
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[0-7].xx, %_:v[8-15].xx, %_:v[20-23].xx
|
||||||
|
//! v_nop
|
||||||
|
//! v4: %_:v[48-51] = v_wmma_f16_16x16x16_f16 %_:v[24-31].xx, %_:v[16-23].xx, %_:v[48-51].xx
|
||||||
|
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(0));
|
||||||
|
Operand A(PhysReg(256 + 0), v8);
|
||||||
|
Operand B(PhysReg(256 + 8), v8);
|
||||||
|
Operand C(PhysReg(256 + 20), v4);
|
||||||
|
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||||
|
0);
|
||||||
|
A.setFixed(PhysReg(256 + 24));
|
||||||
|
B.setFixed(PhysReg(256 + 16));
|
||||||
|
C.setFixed(PhysReg(256 + 48));
|
||||||
|
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||||
|
0);
|
||||||
|
|
||||||
|
/* Mitigation. */
|
||||||
|
//! p_unit_test 1
|
||||||
|
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[0-7].xx, %_:v[8-15].xx, %_:v[20-23].xx
|
||||||
|
//! v1: %_:v[56] = v_rcp_f32 0
|
||||||
|
//! v4: %_:v[48-51] = v_wmma_f16_16x16x16_f16 %_:v[24-31].xx, %_:v[16-23].xx, %_:v[48-51].xx
|
||||||
|
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(1));
|
||||||
|
A.setFixed(PhysReg(256 + 0));
|
||||||
|
B.setFixed(PhysReg(256 + 8));
|
||||||
|
C.setFixed(PhysReg(256 + 20));
|
||||||
|
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||||
|
0);
|
||||||
|
bld.vop1(aco_opcode::v_rcp_f32, Definition(PhysReg(256 + 56), v1), Operand::zero());
|
||||||
|
A.setFixed(PhysReg(256 + 24));
|
||||||
|
B.setFixed(PhysReg(256 + 16));
|
||||||
|
C.setFixed(PhysReg(256 + 48));
|
||||||
|
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||||
|
0);
|
||||||
|
|
||||||
|
/* No hazard. */
|
||||||
|
//>> p_unit_test 2
|
||||||
|
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[0-7].xx, %_:v[8-15].xx, %_:v[20-23].xx
|
||||||
|
//! v4: %_:v[48-51] = v_wmma_f16_16x16x16_f16 %_:v[24-31].xx, %_:v[32-39].xx, %_:v[48-51].xx
|
||||||
|
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(2));
|
||||||
|
A.setFixed(PhysReg(256 + 0));
|
||||||
|
B.setFixed(PhysReg(256 + 8));
|
||||||
|
C.setFixed(PhysReg(256 + 20));
|
||||||
|
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||||
|
0);
|
||||||
|
A.setFixed(PhysReg(256 + 24));
|
||||||
|
B.setFixed(PhysReg(256 + 32));
|
||||||
|
C.setFixed(PhysReg(256 + 48));
|
||||||
|
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||||
|
0);
|
||||||
|
|
||||||
|
//>> p_unit_test 3
|
||||||
|
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[0-7].xx, %_:v[8-15].xx, %_:v[20-23].xx
|
||||||
|
//! v4: %_:v[20-23] = v_wmma_f16_16x16x16_f16 %_:v[24-31].xx, %_:v[32-39].xx, %_:v[20-23].xx
|
||||||
|
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(3));
|
||||||
|
A.setFixed(PhysReg(256 + 0));
|
||||||
|
B.setFixed(PhysReg(256 + 8));
|
||||||
|
C.setFixed(PhysReg(256 + 20));
|
||||||
|
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||||
|
0);
|
||||||
|
A.setFixed(PhysReg(256 + 24));
|
||||||
|
B.setFixed(PhysReg(256 + 32));
|
||||||
|
C.setFixed(PhysReg(256 + 20));
|
||||||
|
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||||
|
0);
|
||||||
|
|
||||||
|
finish_insert_nops_test();
|
||||||
|
END_TEST
|
||||||
|
|
||||||
BEGIN_TEST(insert_nops.setpc_gfx6)
|
BEGIN_TEST(insert_nops.setpc_gfx6)
|
||||||
if (!setup_cs(NULL, GFX6))
|
if (!setup_cs(NULL, GFX6))
|
||||||
return;
|
return;
|
||||||
@@ -1447,6 +1520,20 @@ BEGIN_TEST(insert_nops.setpc_gfx11)
|
|||||||
bld.ds(aco_opcode::ds_read_b32, Definition(PhysReg(256), v1), Operand(PhysReg(256), v1));
|
bld.ds(aco_opcode::ds_read_b32, Definition(PhysReg(256), v1), Operand(PhysReg(256), v1));
|
||||||
bld.sop1(aco_opcode::s_setpc_b64, Operand::zero(8));
|
bld.sop1(aco_opcode::s_setpc_b64, Operand::zero(8));
|
||||||
|
|
||||||
|
/* WMMA Hazards */
|
||||||
|
//! p_unit_test 7
|
||||||
|
//! v4: %0:v[20-23] = v_wmma_f16_16x16x16_f16 %0:v[0-7].xx, %0:v[8-15].xx, %0:v[20-23].xx
|
||||||
|
//! v_nop
|
||||||
|
//! s_waitcnt_depctr va_vdst(0)
|
||||||
|
//! s_setpc_b64 0
|
||||||
|
bld.pseudo(aco_opcode::p_unit_test, Operand::c32(7));
|
||||||
|
Operand A(PhysReg(256 + 0), v8);
|
||||||
|
Operand B(PhysReg(256 + 8), v8);
|
||||||
|
Operand C(PhysReg(256 + 20), v4);
|
||||||
|
bld.vop3p(aco_opcode::v_wmma_f16_16x16x16_f16, Definition(C.physReg(), C.regClass()), A, B, C, 0,
|
||||||
|
0);
|
||||||
|
bld.sop1(aco_opcode::s_setpc_b64, Operand::zero(8));
|
||||||
|
|
||||||
finish_insert_nops_test(true);
|
finish_insert_nops_test(true);
|
||||||
END_TEST
|
END_TEST
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user