aco: Add WMMA instructions.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24683>
This commit is contained in:
Bas Nieuwenhuizen
2023-07-15 19:49:49 +02:00
committed by Marge Bot
parent a29cd20d17
commit 5e7c828c0e
8 changed files with 77 additions and 6 deletions

View File

@@ -875,7 +875,8 @@ gen_alu(Instruction* instr, wait_ctx& ctx)
for (const Definition& def : instr->definitions)
insert_wait_entry(ctx, def, event, 0, cycle_info.latency);
}
update_alu(ctx, is_valu, is_trans, clear, cycle_info.issue_cycles);
update_alu(ctx, is_valu && instr_info.classes[(int)instr->opcode] != instr_class::wmma, is_trans,
clear, cycle_info.issue_cycles);
}
void

View File

@@ -8087,6 +8087,48 @@ create_fs_dual_src_export_gfx11(isel_context* ctx, const struct aco_export_mrt*
ctx->program->has_color_exports = true;
}
static void
visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr)
{
aco_opcode opcode = aco_opcode::num_opcodes;
unsigned signed_mask = 0;
bool clamp = false;
switch (instr->src[0].ssa->bit_size) {
case 16:
switch (instr->def.bit_size) {
case 32: opcode = aco_opcode::v_wmma_f32_16x16x16_f16; break;
case 16: opcode = aco_opcode::v_wmma_f16_16x16x16_f16; break;
}
break;
case 8:
opcode = aco_opcode::v_wmma_i32_16x16x16_iu8;
signed_mask = nir_intrinsic_cmat_signed_mask(instr);
clamp = nir_intrinsic_saturate(instr);
break;
}
if (opcode == aco_opcode::num_opcodes)
unreachable("visit_cmat_muladd: invalid bit size combination");
Builder bld(ctx->program, ctx->block);
Temp dst = get_ssa_temp(ctx, &instr->def);
Operand A(as_vgpr(ctx, get_ssa_temp(ctx, instr->src[0].ssa)));
Operand B(as_vgpr(ctx, get_ssa_temp(ctx, instr->src[1].ssa)));
Operand C(as_vgpr(ctx, get_ssa_temp(ctx, instr->src[2].ssa)));
A.setLateKill(true);
B.setLateKill(true);
VALU_instruction& vop3p = bld.vop3p(opcode, Definition(dst), A, B, C, 0, 0)->valu();
vop3p.neg_lo[0] = (signed_mask & 0x1) != 0;
vop3p.neg_lo[1] = (signed_mask & 0x2) != 0;
vop3p.clamp = clamp;
emit_split_vector(ctx, dst, instr->def.num_components);
}
void
visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
{
@@ -9174,6 +9216,7 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
bld.pseudo(aco_opcode::p_pops_gfx9_ordered_section_done);
break;
}
case nir_intrinsic_cmat_muladd_amd: visit_cmat_muladd(ctx, instr); break;
default:
isel_err(&instr->instr, "Unimplemented intrinsic instr");
abort();

View File

@@ -535,7 +535,8 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_intrinsic_bvh64_intersect_ray_amd:
case nir_intrinsic_load_vector_arg_amd:
case nir_intrinsic_load_rt_dynamic_callable_stack_base_amd:
case nir_intrinsic_ordered_xfb_counter_add_amd: type = RegType::vgpr; break;
case nir_intrinsic_ordered_xfb_counter_add_amd:
case nir_intrinsic_cmat_muladd_amd: type = RegType::vgpr; break;
case nir_intrinsic_load_shared:
case nir_intrinsic_load_shared2_amd:
/* When the result of these loads is only used by cross-lane instructions,

View File

@@ -133,6 +133,7 @@ enum class instr_class : uint8_t {
vmem = 17,
waitcnt = 18,
other = 19,
wmma = 20,
count,
};

View File

@@ -48,6 +48,7 @@ class InstrClass(Enum):
VMem = 17
Waitcnt = 18
Other = 19
WMMA = 20
class Format(Enum):
PSEUDO = 0
@@ -1051,6 +1052,12 @@ opcode("v_dot8_i32_iu4", -1, -1, -1, 0x18, Format.VOP3P, InstrClass.Valu32)
opcode("v_dot8_u32_u4", -1, 0x2b, 0x19, 0x19, Format.VOP3P, InstrClass.Valu32)
opcode("v_dot2_f32_f16", -1, 0x23, 0x13, 0x13, Format.VOP3P, InstrClass.Valu32)
opcode("v_dot2_f32_bf16", -1, -1, -1, 0x1a, Format.VOP3P, InstrClass.Valu32)
opcode("v_wmma_f32_16x16x16_f16", -1, -1, -1, 0x40, Format.VOP3P, InstrClass.WMMA, False, False)
opcode("v_wmma_f32_16x16x16_bf16", -1, -1, -1, 0x41, Format.VOP3P, InstrClass.WMMA, False, False)
opcode("v_wmma_f16_16x16x16_f16", -1, -1, -1, 0x42, Format.VOP3P, InstrClass.WMMA, False, False)
opcode("v_wmma_bf16_16x16x16_bf16", -1, -1, -1, 0x43, Format.VOP3P, InstrClass.WMMA, False, False)
opcode("v_wmma_i32_16x16x16_iu8", -1, -1, -1, 0x44, Format.VOP3P, InstrClass.WMMA, False, False)
opcode("v_wmma_i32_16x16x16_iu4", -1, -1, -1, 0x45, Format.VOP3P, InstrClass.WMMA, False, False)
# VINTRP (GFX6 - GFX10.3) instructions:

View File

@@ -643,7 +643,13 @@ can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
instr->opcode != aco_opcode::v_interp_p10_f16_f32_inreg &&
instr->opcode != aco_opcode::v_interp_p2_f16_f32_inreg &&
instr->opcode != aco_opcode::v_interp_p10_rtz_f16_f32_inreg &&
instr->opcode != aco_opcode::v_interp_p2_rtz_f16_f32_inreg;
instr->opcode != aco_opcode::v_interp_p2_rtz_f16_f32_inreg &&
instr->opcode != aco_opcode::v_wmma_f32_16x16x16_f16 &&
instr->opcode != aco_opcode::v_wmma_f32_16x16x16_bf16 &&
instr->opcode != aco_opcode::v_wmma_f16_16x16x16_f16 &&
instr->opcode != aco_opcode::v_wmma_bf16_16x16x16_bf16 &&
instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu8 &&
instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4;
}
bool
@@ -697,7 +703,13 @@ alu_can_accept_constant(const aco_ptr<Instruction>& instr, unsigned operand)
case aco_opcode::v_interp_p10_f16_f32_inreg:
case aco_opcode::v_interp_p2_f16_f32_inreg:
case aco_opcode::v_interp_p10_rtz_f16_f32_inreg:
case aco_opcode::v_interp_p2_rtz_f16_f32_inreg: return false;
case aco_opcode::v_interp_p2_rtz_f16_f32_inreg:
case aco_opcode::v_wmma_f32_16x16x16_f16:
case aco_opcode::v_wmma_f32_16x16x16_bf16:
case aco_opcode::v_wmma_f16_16x16x16_f16:
case aco_opcode::v_wmma_bf16_16x16x16_bf16:
case aco_opcode::v_wmma_i32_16x16x16_iu8:
case aco_opcode::v_wmma_i32_16x16x16_iu4: return false;
default: return true;
}
}

View File

@@ -223,6 +223,11 @@ get_perf_info(const Program& program, const Instruction& instr)
: perf_info{0, WAIT_USE(lds, 1)};
case instr_class::exp: return {0, WAIT_USE(export_gds, 1)};
case instr_class::vmem: return {0, WAIT_USE(vmem, 1)};
case instr_class::wmma: {
/* int8 and (b)f16 have the same performance. */
uint8_t cost = instr.opcode == aco_opcode::v_wmma_i32_16x16x16_iu4 ? 16 : 32;
return {cost, WAIT_USE(valu, cost)};
}
case instr_class::barrier:
case instr_class::waitcnt:
case instr_class::other:

View File

@@ -259,8 +259,9 @@ validate_ir(Program* program)
check(!vop3p.opsel_lo[i] && !vop3p.opsel_hi[i],
"Unexpected opsel for subdword operand", instr.get());
}
check(instr->definitions[0].regClass() == v1, "VOP3P must have v1 definition",
instr.get());
check(instr->definitions[0].regClass() == v1 ||
instr_info.classes[(int)instr->opcode] == instr_class::wmma,
"VOP3P must have v1 definition", instr.get());
}
/* check for undefs */