aco: Add WMMA instructions.
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24683>
This commit is contained in:

committed by
Marge Bot

parent
a29cd20d17
commit
5e7c828c0e
@@ -875,7 +875,8 @@ gen_alu(Instruction* instr, wait_ctx& ctx)
|
||||
for (const Definition& def : instr->definitions)
|
||||
insert_wait_entry(ctx, def, event, 0, cycle_info.latency);
|
||||
}
|
||||
update_alu(ctx, is_valu, is_trans, clear, cycle_info.issue_cycles);
|
||||
update_alu(ctx, is_valu && instr_info.classes[(int)instr->opcode] != instr_class::wmma, is_trans,
|
||||
clear, cycle_info.issue_cycles);
|
||||
}
|
||||
|
||||
void
|
||||
|
@@ -8087,6 +8087,48 @@ create_fs_dual_src_export_gfx11(isel_context* ctx, const struct aco_export_mrt*
|
||||
ctx->program->has_color_exports = true;
|
||||
}
|
||||
|
||||
static void
|
||||
visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr)
|
||||
{
|
||||
aco_opcode opcode = aco_opcode::num_opcodes;
|
||||
unsigned signed_mask = 0;
|
||||
bool clamp = false;
|
||||
|
||||
switch (instr->src[0].ssa->bit_size) {
|
||||
case 16:
|
||||
switch (instr->def.bit_size) {
|
||||
case 32: opcode = aco_opcode::v_wmma_f32_16x16x16_f16; break;
|
||||
case 16: opcode = aco_opcode::v_wmma_f16_16x16x16_f16; break;
|
||||
}
|
||||
break;
|
||||
case 8:
|
||||
opcode = aco_opcode::v_wmma_i32_16x16x16_iu8;
|
||||
signed_mask = nir_intrinsic_cmat_signed_mask(instr);
|
||||
clamp = nir_intrinsic_saturate(instr);
|
||||
break;
|
||||
}
|
||||
|
||||
if (opcode == aco_opcode::num_opcodes)
|
||||
unreachable("visit_cmat_muladd: invalid bit size combination");
|
||||
|
||||
Builder bld(ctx->program, ctx->block);
|
||||
|
||||
Temp dst = get_ssa_temp(ctx, &instr->def);
|
||||
Operand A(as_vgpr(ctx, get_ssa_temp(ctx, instr->src[0].ssa)));
|
||||
Operand B(as_vgpr(ctx, get_ssa_temp(ctx, instr->src[1].ssa)));
|
||||
Operand C(as_vgpr(ctx, get_ssa_temp(ctx, instr->src[2].ssa)));
|
||||
|
||||
A.setLateKill(true);
|
||||
B.setLateKill(true);
|
||||
|
||||
VALU_instruction& vop3p = bld.vop3p(opcode, Definition(dst), A, B, C, 0, 0)->valu();
|
||||
vop3p.neg_lo[0] = (signed_mask & 0x1) != 0;
|
||||
vop3p.neg_lo[1] = (signed_mask & 0x2) != 0;
|
||||
vop3p.clamp = clamp;
|
||||
|
||||
emit_split_vector(ctx, dst, instr->def.num_components);
|
||||
}
|
||||
|
||||
void
|
||||
visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
|
||||
{
|
||||
@@ -9174,6 +9216,7 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
|
||||
bld.pseudo(aco_opcode::p_pops_gfx9_ordered_section_done);
|
||||
break;
|
||||
}
|
||||
case nir_intrinsic_cmat_muladd_amd: visit_cmat_muladd(ctx, instr); break;
|
||||
default:
|
||||
isel_err(&instr->instr, "Unimplemented intrinsic instr");
|
||||
abort();
|
||||
|
@@ -535,7 +535,8 @@ init_context(isel_context* ctx, nir_shader* shader)
|
||||
case nir_intrinsic_bvh64_intersect_ray_amd:
|
||||
case nir_intrinsic_load_vector_arg_amd:
|
||||
case nir_intrinsic_load_rt_dynamic_callable_stack_base_amd:
|
||||
case nir_intrinsic_ordered_xfb_counter_add_amd: type = RegType::vgpr; break;
|
||||
case nir_intrinsic_ordered_xfb_counter_add_amd:
|
||||
case nir_intrinsic_cmat_muladd_amd: type = RegType::vgpr; break;
|
||||
case nir_intrinsic_load_shared:
|
||||
case nir_intrinsic_load_shared2_amd:
|
||||
/* When the result of these loads is only used by cross-lane instructions,
|
||||
|
@@ -133,6 +133,7 @@ enum class instr_class : uint8_t {
|
||||
vmem = 17,
|
||||
waitcnt = 18,
|
||||
other = 19,
|
||||
wmma = 20,
|
||||
count,
|
||||
};
|
||||
|
||||
|
@@ -48,6 +48,7 @@ class InstrClass(Enum):
|
||||
VMem = 17
|
||||
Waitcnt = 18
|
||||
Other = 19
|
||||
WMMA = 20
|
||||
|
||||
class Format(Enum):
|
||||
PSEUDO = 0
|
||||
@@ -1051,6 +1052,12 @@ opcode("v_dot8_i32_iu4", -1, -1, -1, 0x18, Format.VOP3P, InstrClass.Valu32)
|
||||
opcode("v_dot8_u32_u4", -1, 0x2b, 0x19, 0x19, Format.VOP3P, InstrClass.Valu32)
|
||||
opcode("v_dot2_f32_f16", -1, 0x23, 0x13, 0x13, Format.VOP3P, InstrClass.Valu32)
|
||||
opcode("v_dot2_f32_bf16", -1, -1, -1, 0x1a, Format.VOP3P, InstrClass.Valu32)
|
||||
opcode("v_wmma_f32_16x16x16_f16", -1, -1, -1, 0x40, Format.VOP3P, InstrClass.WMMA, False, False)
|
||||
opcode("v_wmma_f32_16x16x16_bf16", -1, -1, -1, 0x41, Format.VOP3P, InstrClass.WMMA, False, False)
|
||||
opcode("v_wmma_f16_16x16x16_f16", -1, -1, -1, 0x42, Format.VOP3P, InstrClass.WMMA, False, False)
|
||||
opcode("v_wmma_bf16_16x16x16_bf16", -1, -1, -1, 0x43, Format.VOP3P, InstrClass.WMMA, False, False)
|
||||
opcode("v_wmma_i32_16x16x16_iu8", -1, -1, -1, 0x44, Format.VOP3P, InstrClass.WMMA, False, False)
|
||||
opcode("v_wmma_i32_16x16x16_iu4", -1, -1, -1, 0x45, Format.VOP3P, InstrClass.WMMA, False, False)
|
||||
|
||||
|
||||
# VINTRP (GFX6 - GFX10.3) instructions:
|
||||
|
@@ -643,7 +643,13 @@ can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
instr->opcode != aco_opcode::v_interp_p10_f16_f32_inreg &&
|
||||
instr->opcode != aco_opcode::v_interp_p2_f16_f32_inreg &&
|
||||
instr->opcode != aco_opcode::v_interp_p10_rtz_f16_f32_inreg &&
|
||||
instr->opcode != aco_opcode::v_interp_p2_rtz_f16_f32_inreg;
|
||||
instr->opcode != aco_opcode::v_interp_p2_rtz_f16_f32_inreg &&
|
||||
instr->opcode != aco_opcode::v_wmma_f32_16x16x16_f16 &&
|
||||
instr->opcode != aco_opcode::v_wmma_f32_16x16x16_bf16 &&
|
||||
instr->opcode != aco_opcode::v_wmma_f16_16x16x16_f16 &&
|
||||
instr->opcode != aco_opcode::v_wmma_bf16_16x16x16_bf16 &&
|
||||
instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu8 &&
|
||||
instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4;
|
||||
}
|
||||
|
||||
bool
|
||||
@@ -697,7 +703,13 @@ alu_can_accept_constant(const aco_ptr<Instruction>& instr, unsigned operand)
|
||||
case aco_opcode::v_interp_p10_f16_f32_inreg:
|
||||
case aco_opcode::v_interp_p2_f16_f32_inreg:
|
||||
case aco_opcode::v_interp_p10_rtz_f16_f32_inreg:
|
||||
case aco_opcode::v_interp_p2_rtz_f16_f32_inreg: return false;
|
||||
case aco_opcode::v_interp_p2_rtz_f16_f32_inreg:
|
||||
case aco_opcode::v_wmma_f32_16x16x16_f16:
|
||||
case aco_opcode::v_wmma_f32_16x16x16_bf16:
|
||||
case aco_opcode::v_wmma_f16_16x16x16_f16:
|
||||
case aco_opcode::v_wmma_bf16_16x16x16_bf16:
|
||||
case aco_opcode::v_wmma_i32_16x16x16_iu8:
|
||||
case aco_opcode::v_wmma_i32_16x16x16_iu4: return false;
|
||||
default: return true;
|
||||
}
|
||||
}
|
||||
|
@@ -223,6 +223,11 @@ get_perf_info(const Program& program, const Instruction& instr)
|
||||
: perf_info{0, WAIT_USE(lds, 1)};
|
||||
case instr_class::exp: return {0, WAIT_USE(export_gds, 1)};
|
||||
case instr_class::vmem: return {0, WAIT_USE(vmem, 1)};
|
||||
case instr_class::wmma: {
|
||||
/* int8 and (b)f16 have the same performance. */
|
||||
uint8_t cost = instr.opcode == aco_opcode::v_wmma_i32_16x16x16_iu4 ? 16 : 32;
|
||||
return {cost, WAIT_USE(valu, cost)};
|
||||
}
|
||||
case instr_class::barrier:
|
||||
case instr_class::waitcnt:
|
||||
case instr_class::other:
|
||||
|
@@ -259,8 +259,9 @@ validate_ir(Program* program)
|
||||
check(!vop3p.opsel_lo[i] && !vop3p.opsel_hi[i],
|
||||
"Unexpected opsel for subdword operand", instr.get());
|
||||
}
|
||||
check(instr->definitions[0].regClass() == v1, "VOP3P must have v1 definition",
|
||||
instr.get());
|
||||
check(instr->definitions[0].regClass() == v1 ||
|
||||
instr_info.classes[(int)instr->opcode] == instr_class::wmma,
|
||||
"VOP3P must have v1 definition", instr.get());
|
||||
}
|
||||
|
||||
/* check for undefs */
|
||||
|
Reference in New Issue
Block a user