aco: combine add/mul as v_fma_mix into fma
fossil-db (Sienna Cichlid):
Totals from 7345 (5.44% of 134913) affected shaders:
CodeSize: 73840060 -> 73768936 (-0.10%); split: -0.10%, +0.00%
Instrs: 13701603 -> 13684183 (-0.13%); split: -0.13%, +0.00%
Latency: 185389373 -> 185306538 (-0.04%); split: -0.04%, +0.00%
InvThroughput: 33785020 -> 33757593 (-0.08%); split: -0.08%, +0.00%
VClause: 237337 -> 237338 (+0.00%)
SClause: 485728 -> 485720 (-0.00%)
Copies: 935900 -> 935279 (-0.07%); split: -0.07%, +0.00%
Branches: 480721 -> 480722 (+0.00%)

fossil-db (Navi):
Totals from 10649 (7.89% of 134913) affected shaders:
VGPRs: 756624 -> 756516 (-0.01%); split: -0.02%, +0.01%
CodeSize: 92156580 -> 91707900 (-0.49%); split: -0.49%, +0.00%
MaxWaves: 159402 -> 159476 (+0.05%); split: +0.07%, -0.02%
Instrs: 17155827 -> 17070449 (-0.50%); split: -0.50%, +0.00%
Latency: 246296456 -> 245487120 (-0.33%); split: -0.33%, +0.00%
InvThroughput: 41438159 -> 41117424 (-0.77%); split: -0.77%, +0.00%
VClause: 323790 -> 323867 (+0.02%); split: -0.00%, +0.03%
SClause: 612077 -> 612034 (-0.01%); split: -0.01%, +0.00%
Copies: 1103012 -> 1102775 (-0.02%); split: -0.03%, +0.01%
Branches: 555893 -> 555896 (+0.00%); split: -0.00%, +0.00%
PreSGPRs: 824372 -> 824378 (+0.00%)
PreVGPRs: 740390 -> 740363 (-0.00%); split: -0.01%, +0.01%

fossil-db (Vega):
Totals from 10950 (8.11% of 135048) affected shaders:
SGPRs: 1034528 -> 1034560 (+0.00%)
VGPRs: 794092 -> 794104 (+0.00%); split: -0.01%, +0.01%
CodeSize: 94409768 -> 93955568 (-0.48%); split: -0.48%, +0.00%
MaxWaves: 38950 -> 38939 (-0.03%); split: +0.00%, -0.03%
Instrs: 18162637 -> 18070934 (-0.50%); split: -0.51%, +0.00%
Latency: 291718455 -> 290772451 (-0.32%); split: -0.32%, +0.00%
InvThroughput: 109114674 -> 108489767 (-0.57%); split: -0.57%, +0.00%
VClause: 334498 -> 334579 (+0.02%); split: -0.01%, +0.03%
SClause: 628871 -> 628825 (-0.01%); split: -0.01%, +0.00%
Copies: 1674477 -> 1674850 (+0.02%); split: -0.02%, +0.04%
PreSGPRs: 834800 -> 834802 (+0.00%)
PreVGPRs: 750460 -> 750415 (-0.01%); split: -0.01%, +0.01%

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14769>
This commit is contained in:
@@ -3689,18 +3689,25 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
}
|
||||
|
||||
/* combine mul+add -> mad */
|
||||
bool is_add_mix =
|
||||
(instr->opcode == aco_opcode::v_fma_mix_f32 ||
|
||||
instr->opcode == aco_opcode::v_fma_mixlo_f16) &&
|
||||
!instr->vop3p().neg_lo[0] &&
|
||||
((instr->operands[0].constantEquals(0x3f800000) && (instr->vop3p().opsel_hi & 0x1) == 0) ||
|
||||
(instr->operands[0].constantEquals(0x3C00) && (instr->vop3p().opsel_hi & 0x1) &&
|
||||
!(instr->vop3p().opsel_lo & 0x1)));
|
||||
bool mad32 = instr->opcode == aco_opcode::v_add_f32 || instr->opcode == aco_opcode::v_sub_f32 ||
|
||||
instr->opcode == aco_opcode::v_subrev_f32;
|
||||
bool mad16 = instr->opcode == aco_opcode::v_add_f16 || instr->opcode == aco_opcode::v_sub_f16 ||
|
||||
instr->opcode == aco_opcode::v_subrev_f16;
|
||||
bool mad64 = instr->opcode == aco_opcode::v_add_f64;
|
||||
if (mad16 || mad32 || mad64) {
|
||||
if (is_add_mix || mad16 || mad32 || mad64) {
|
||||
Instruction* mul_instr = nullptr;
|
||||
unsigned add_op_idx = 0;
|
||||
uint32_t uses = UINT32_MAX;
|
||||
bool emit_fma = false;
|
||||
/* find the 'best' mul instruction to combine with the add */
|
||||
for (unsigned i = 0; i < 2; i++) {
|
||||
for (unsigned i = is_add_mix ? 1 : 0; i < instr->operands.size(); i++) {
|
||||
if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul())
|
||||
continue;
|
||||
ssa_info& info = ctx.info[instr->operands[i].tempId()];
|
||||
@@ -3708,26 +3715,39 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
/* no clamp/omod allowed between mul and add */
|
||||
if (info.instr->isVOP3() && (info.instr->vop3().clamp || info.instr->vop3().omod))
|
||||
continue;
|
||||
if (info.instr->isVOP3P())
|
||||
if (info.instr->isVOP3P() && info.instr->vop3p().clamp)
|
||||
continue;
|
||||
/* v_fma_mix_f32/etc can't do omod */
|
||||
if (info.instr->isVOP3P() && instr->isVOP3() && instr->vop3().omod)
|
||||
continue;
|
||||
/* don't promote fp16 to fp32 or remove fp32->fp16->fp32 conversions */
|
||||
if (is_add_mix && info.instr->definitions[0].bytes() == 2)
|
||||
continue;
|
||||
|
||||
if (get_operand_size(instr, i) != info.instr->definitions[0].bytes() * 8)
|
||||
continue;
|
||||
|
||||
bool legacy = info.instr->opcode == aco_opcode::v_mul_legacy_f32;
|
||||
bool mad_mix = is_add_mix || info.instr->isVOP3P();
|
||||
|
||||
bool has_fma = mad16 || mad64 || (legacy && ctx.program->chip_class >= GFX10_3) ||
|
||||
(mad32 && !legacy && ctx.program->dev.has_fast_fma32);
|
||||
bool has_mad = (mad32 && ctx.program->chip_class < GFX10_3) ||
|
||||
(mad16 && ctx.program->chip_class <= GFX9);
|
||||
(mad32 && !legacy && !mad_mix && ctx.program->dev.has_fast_fma32) ||
|
||||
(mad_mix && ctx.program->dev.fused_mad_mix);
|
||||
bool has_mad = mad_mix ? !ctx.program->dev.fused_mad_mix
|
||||
: ((mad32 && ctx.program->chip_class < GFX10_3) ||
|
||||
(mad16 && ctx.program->chip_class <= GFX9));
|
||||
bool can_use_fma = has_fma && !info.instr->definitions[0].isPrecise() &&
|
||||
!instr->definitions[0].isPrecise();
|
||||
bool can_use_mad =
|
||||
has_mad && (mad32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == 0;
|
||||
has_mad && (mad_mix || mad32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == 0;
|
||||
if (mad_mix && legacy)
|
||||
continue;
|
||||
if (!can_use_fma && !can_use_mad)
|
||||
continue;
|
||||
|
||||
Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]};
|
||||
unsigned candidate_add_op_idx = is_add_mix ? (3 - i) : (1 - i);
|
||||
Operand op[3] = {info.instr->operands[0], info.instr->operands[1],
|
||||
instr->operands[candidate_add_op_idx]};
|
||||
if (info.instr->isSDWA() || info.instr->isDPP() || !check_vop3_operands(ctx, 3, op) ||
|
||||
ctx.uses[instr->operands[i].tempId()] > uses)
|
||||
continue;
|
||||
@@ -3740,7 +3760,7 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
}
|
||||
|
||||
mul_instr = info.instr;
|
||||
add_op_idx = 1 - i;
|
||||
add_op_idx = candidate_add_op_idx;
|
||||
uses = ctx.uses[instr->operands[i].tempId()];
|
||||
emit_fma = !can_use_mad;
|
||||
}
|
||||
@@ -3761,6 +3781,8 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
bool abs[3] = {false, false, false};
|
||||
unsigned omod = 0;
|
||||
bool clamp = false;
|
||||
uint8_t opsel_lo = 0;
|
||||
uint8_t opsel_hi = 0;
|
||||
|
||||
if (mul_instr->isVOP3()) {
|
||||
VOP3_instruction& vop3 = mul_instr->vop3();
|
||||
@@ -3768,6 +3790,14 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
neg[1] = vop3.neg[1];
|
||||
abs[0] = vop3.abs[0];
|
||||
abs[1] = vop3.abs[1];
|
||||
} else if (mul_instr->isVOP3P()) {
|
||||
VOP3P_instruction& vop3p = mul_instr->vop3p();
|
||||
neg[0] = vop3p.neg_lo[0];
|
||||
neg[1] = vop3p.neg_lo[1];
|
||||
abs[0] = vop3p.neg_hi[0];
|
||||
abs[1] = vop3p.neg_hi[1];
|
||||
opsel_lo = vop3p.opsel_lo & 0x3;
|
||||
opsel_hi = vop3p.opsel_hi & 0x3;
|
||||
}
|
||||
|
||||
if (instr->isVOP3()) {
|
||||
@@ -3785,41 +3815,79 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
}
|
||||
/* neg of the multiplication result */
|
||||
neg[1] = neg[1] ^ vop3.neg[1 - add_op_idx];
|
||||
} else if (instr->isVOP3P()) {
|
||||
VOP3P_instruction& vop3p = instr->vop3p();
|
||||
neg[2] = vop3p.neg_lo[add_op_idx];
|
||||
abs[2] = vop3p.neg_hi[add_op_idx];
|
||||
opsel_lo |= vop3p.opsel_lo & (1 << add_op_idx) ? 0x4 : 0x0;
|
||||
opsel_hi |= vop3p.opsel_hi & (1 << add_op_idx) ? 0x4 : 0x0;
|
||||
clamp = vop3p.clamp;
|
||||
/* abs of the multiplication result */
|
||||
if (vop3p.neg_hi[3 - add_op_idx]) {
|
||||
neg[0] = false;
|
||||
neg[1] = false;
|
||||
abs[0] = true;
|
||||
abs[1] = true;
|
||||
}
|
||||
/* neg of the multiplication result */
|
||||
neg[1] = neg[1] ^ vop3p.neg_lo[3 - add_op_idx];
|
||||
}
|
||||
|
||||
if (instr->opcode == aco_opcode::v_sub_f32 || instr->opcode == aco_opcode::v_sub_f16)
|
||||
neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
|
||||
else if (instr->opcode == aco_opcode::v_subrev_f32 ||
|
||||
instr->opcode == aco_opcode::v_subrev_f16)
|
||||
neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
|
||||
|
||||
aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
|
||||
if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
|
||||
assert(emit_fma == (ctx.program->chip_class >= GFX10_3));
|
||||
mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : aco_opcode::v_mad_legacy_f32;
|
||||
} else if (mad16) {
|
||||
mad_op = emit_fma ? (ctx.program->chip_class == GFX8 ? aco_opcode::v_fma_legacy_f16
|
||||
: aco_opcode::v_fma_f16)
|
||||
: (ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_f16
|
||||
: aco_opcode::v_mad_f16);
|
||||
} else if (mad64) {
|
||||
mad_op = aco_opcode::v_fma_f64;
|
||||
}
|
||||
aco_ptr<Instruction> add_instr = std::move(instr);
|
||||
if (add_instr->isVOP3P() || mul_instr->isVOP3P()) {
|
||||
assert(!omod);
|
||||
|
||||
aco_ptr<VOP3_instruction> mad{
|
||||
create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
|
||||
for (unsigned i = 0; i < 3; i++) {
|
||||
mad->operands[i] = op[i];
|
||||
mad->neg[i] = neg[i];
|
||||
mad->abs[i] = abs[i];
|
||||
aco_opcode mad_op = add_instr->definitions[0].bytes() == 2 ? aco_opcode::v_fma_mixlo_f16
|
||||
: aco_opcode::v_fma_mix_f32;
|
||||
aco_ptr<VOP3P_instruction> mad{
|
||||
create_instruction<VOP3P_instruction>(mad_op, Format::VOP3P, 3, 1)};
|
||||
for (unsigned i = 0; i < 3; i++) {
|
||||
mad->operands[i] = op[i];
|
||||
mad->neg_lo[i] = neg[i];
|
||||
mad->neg_hi[i] = abs[i];
|
||||
}
|
||||
mad->clamp = clamp;
|
||||
mad->opsel_lo = opsel_lo;
|
||||
mad->opsel_hi = opsel_hi;
|
||||
|
||||
instr = std::move(mad);
|
||||
} else {
|
||||
aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
|
||||
if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
|
||||
assert(emit_fma == (ctx.program->chip_class >= GFX10_3));
|
||||
mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : aco_opcode::v_mad_legacy_f32;
|
||||
} else if (mad16) {
|
||||
mad_op = emit_fma ? (ctx.program->chip_class == GFX8 ? aco_opcode::v_fma_legacy_f16
|
||||
: aco_opcode::v_fma_f16)
|
||||
: (ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_f16
|
||||
: aco_opcode::v_mad_f16);
|
||||
} else if (mad64) {
|
||||
mad_op = aco_opcode::v_fma_f64;
|
||||
}
|
||||
|
||||
aco_ptr<VOP3_instruction> mad{
|
||||
create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
|
||||
for (unsigned i = 0; i < 3; i++) {
|
||||
mad->operands[i] = op[i];
|
||||
mad->neg[i] = neg[i];
|
||||
mad->abs[i] = abs[i];
|
||||
}
|
||||
mad->omod = omod;
|
||||
mad->clamp = clamp;
|
||||
|
||||
instr = std::move(mad);
|
||||
}
|
||||
mad->omod = omod;
|
||||
mad->clamp = clamp;
|
||||
mad->definitions[0] = instr->definitions[0];
|
||||
instr->definitions[0] = add_instr->definitions[0];
|
||||
|
||||
/* mark this ssa_def to be re-checked for profitability and literals */
|
||||
ctx.mad_infos.emplace_back(std::move(instr), mul_instr->definitions[0].tempId());
|
||||
ctx.info[mad->definitions[0].tempId()].set_mad(mad.get(), ctx.mad_infos.size() - 1);
|
||||
instr = std::move(mad);
|
||||
ctx.mad_infos.emplace_back(std::move(add_instr), mul_instr->definitions[0].tempId());
|
||||
ctx.info[instr->definitions[0].tempId()].set_mad(instr.get(), ctx.mad_infos.size() - 1);
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -4084,7 +4152,8 @@ select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
mad_info = NULL;
|
||||
}
|
||||
/* check literals */
|
||||
else if (!instr->usesModifiers() && instr->opcode != aco_opcode::v_fma_f64 &&
|
||||
else if (!instr->usesModifiers() && !instr->isVOP3P() &&
|
||||
instr->opcode != aco_opcode::v_fma_f64 &&
|
||||
instr->opcode != aco_opcode::v_mad_legacy_f32 &&
|
||||
instr->opcode != aco_opcode::v_fma_legacy_f32) {
|
||||
/* FMA can only take literals on GFX10+ */
|
||||
|
Reference in New Issue
Block a user