aco: combine add/mul as v_fma_mix into fma

fossil-db (Sienna Cichlid):
Totals from 7345 (5.44% of 134913) affected shaders:
CodeSize: 73840060 -> 73768936 (-0.10%); split: -0.10%, +0.00%
Instrs: 13701603 -> 13684183 (-0.13%); split: -0.13%, +0.00%
Latency: 185389373 -> 185306538 (-0.04%); split: -0.04%, +0.00%
InvThroughput: 33785020 -> 33757593 (-0.08%); split: -0.08%, +0.00%
VClause: 237337 -> 237338 (+0.00%)
SClause: 485728 -> 485720 (-0.00%)
Copies: 935900 -> 935279 (-0.07%); split: -0.07%, +0.00%
Branches: 480721 -> 480722 (+0.00%)

fossil-db (Navi):
Totals from 10649 (7.89% of 134913) affected shaders:
VGPRs: 756624 -> 756516 (-0.01%); split: -0.02%, +0.01%
CodeSize: 92156580 -> 91707900 (-0.49%); split: -0.49%, +0.00%
MaxWaves: 159402 -> 159476 (+0.05%); split: +0.07%, -0.02%
Instrs: 17155827 -> 17070449 (-0.50%); split: -0.50%, +0.00%
Latency: 246296456 -> 245487120 (-0.33%); split: -0.33%, +0.00%
InvThroughput: 41438159 -> 41117424 (-0.77%); split: -0.77%, +0.00%
VClause: 323790 -> 323867 (+0.02%); split: -0.00%, +0.03%
SClause: 612077 -> 612034 (-0.01%); split: -0.01%, +0.00%
Copies: 1103012 -> 1102775 (-0.02%); split: -0.03%, +0.01%
Branches: 555893 -> 555896 (+0.00%); split: -0.00%, +0.00%
PreSGPRs: 824372 -> 824378 (+0.00%)
PreVGPRs: 740390 -> 740363 (-0.00%); split: -0.01%, +0.01%

fossil-db (Vega):
Totals from 10950 (8.11% of 135048) affected shaders:
SGPRs: 1034528 -> 1034560 (+0.00%)
VGPRs: 794092 -> 794104 (+0.00%); split: -0.01%, +0.01%
CodeSize: 94409768 -> 93955568 (-0.48%); split: -0.48%, +0.00%
MaxWaves: 38950 -> 38939 (-0.03%); split: +0.00%, -0.03%
Instrs: 18162637 -> 18070934 (-0.50%); split: -0.51%, +0.00%
Latency: 291718455 -> 290772451 (-0.32%); split: -0.32%, +0.00%
InvThroughput: 109114674 -> 108489767 (-0.57%); split: -0.57%, +0.00%
VClause: 334498 -> 334579 (+0.02%); split: -0.01%, +0.03%
SClause: 628871 -> 628825 (-0.01%); split: -0.01%, +0.00%
Copies: 1674477 -> 1674850 (+0.02%); split: -0.02%, +0.04%
PreSGPRs: 834800 -> 834802 (+0.00%)
PreVGPRs: 750460 -> 750415 (-0.01%); split: -0.01%, +0.01%

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14769>
This commit is contained in:
Rhys Perry
2022-01-17 17:48:33 +00:00
committed by Marge Bot
parent 9934c86761
commit 35196b6d89

View File

@@ -3689,18 +3689,25 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
}
/* combine mul+add -> mad */
bool is_add_mix =
(instr->opcode == aco_opcode::v_fma_mix_f32 ||
instr->opcode == aco_opcode::v_fma_mixlo_f16) &&
!instr->vop3p().neg_lo[0] &&
((instr->operands[0].constantEquals(0x3f800000) && (instr->vop3p().opsel_hi & 0x1) == 0) ||
(instr->operands[0].constantEquals(0x3C00) && (instr->vop3p().opsel_hi & 0x1) &&
!(instr->vop3p().opsel_lo & 0x1)));
bool mad32 = instr->opcode == aco_opcode::v_add_f32 || instr->opcode == aco_opcode::v_sub_f32 ||
instr->opcode == aco_opcode::v_subrev_f32;
bool mad16 = instr->opcode == aco_opcode::v_add_f16 || instr->opcode == aco_opcode::v_sub_f16 ||
instr->opcode == aco_opcode::v_subrev_f16;
bool mad64 = instr->opcode == aco_opcode::v_add_f64;
if (mad16 || mad32 || mad64) {
if (is_add_mix || mad16 || mad32 || mad64) {
Instruction* mul_instr = nullptr;
unsigned add_op_idx = 0;
uint32_t uses = UINT32_MAX;
bool emit_fma = false;
/* find the 'best' mul instruction to combine with the add */
for (unsigned i = 0; i < 2; i++) {
for (unsigned i = is_add_mix ? 1 : 0; i < instr->operands.size(); i++) {
if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul())
continue;
ssa_info& info = ctx.info[instr->operands[i].tempId()];
@@ -3708,26 +3715,39 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
/* no clamp/omod allowed between mul and add */
if (info.instr->isVOP3() && (info.instr->vop3().clamp || info.instr->vop3().omod))
continue;
if (info.instr->isVOP3P())
if (info.instr->isVOP3P() && info.instr->vop3p().clamp)
continue;
/* v_fma_mix_f32/etc can't do omod */
if (info.instr->isVOP3P() && instr->isVOP3() && instr->vop3().omod)
continue;
/* don't promote fp16 to fp32 or remove fp32->fp16->fp32 conversions */
if (is_add_mix && info.instr->definitions[0].bytes() == 2)
continue;
if (get_operand_size(instr, i) != info.instr->definitions[0].bytes() * 8)
continue;
bool legacy = info.instr->opcode == aco_opcode::v_mul_legacy_f32;
bool mad_mix = is_add_mix || info.instr->isVOP3P();
bool has_fma = mad16 || mad64 || (legacy && ctx.program->chip_class >= GFX10_3) ||
(mad32 && !legacy && ctx.program->dev.has_fast_fma32);
bool has_mad = (mad32 && ctx.program->chip_class < GFX10_3) ||
(mad16 && ctx.program->chip_class <= GFX9);
(mad32 && !legacy && !mad_mix && ctx.program->dev.has_fast_fma32) ||
(mad_mix && ctx.program->dev.fused_mad_mix);
bool has_mad = mad_mix ? !ctx.program->dev.fused_mad_mix
: ((mad32 && ctx.program->chip_class < GFX10_3) ||
(mad16 && ctx.program->chip_class <= GFX9));
bool can_use_fma = has_fma && !info.instr->definitions[0].isPrecise() &&
!instr->definitions[0].isPrecise();
bool can_use_mad =
has_mad && (mad32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == 0;
has_mad && (mad_mix || mad32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == 0;
if (mad_mix && legacy)
continue;
if (!can_use_fma && !can_use_mad)
continue;
Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]};
unsigned candidate_add_op_idx = is_add_mix ? (3 - i) : (1 - i);
Operand op[3] = {info.instr->operands[0], info.instr->operands[1],
instr->operands[candidate_add_op_idx]};
if (info.instr->isSDWA() || info.instr->isDPP() || !check_vop3_operands(ctx, 3, op) ||
ctx.uses[instr->operands[i].tempId()] > uses)
continue;
@@ -3740,7 +3760,7 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
}
mul_instr = info.instr;
add_op_idx = 1 - i;
add_op_idx = candidate_add_op_idx;
uses = ctx.uses[instr->operands[i].tempId()];
emit_fma = !can_use_mad;
}
@@ -3761,6 +3781,8 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
bool abs[3] = {false, false, false};
unsigned omod = 0;
bool clamp = false;
uint8_t opsel_lo = 0;
uint8_t opsel_hi = 0;
if (mul_instr->isVOP3()) {
VOP3_instruction& vop3 = mul_instr->vop3();
@@ -3768,6 +3790,14 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
neg[1] = vop3.neg[1];
abs[0] = vop3.abs[0];
abs[1] = vop3.abs[1];
} else if (mul_instr->isVOP3P()) {
VOP3P_instruction& vop3p = mul_instr->vop3p();
neg[0] = vop3p.neg_lo[0];
neg[1] = vop3p.neg_lo[1];
abs[0] = vop3p.neg_hi[0];
abs[1] = vop3p.neg_hi[1];
opsel_lo = vop3p.opsel_lo & 0x3;
opsel_hi = vop3p.opsel_hi & 0x3;
}
if (instr->isVOP3()) {
@@ -3785,41 +3815,79 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
}
/* neg of the multiplication result */
neg[1] = neg[1] ^ vop3.neg[1 - add_op_idx];
} else if (instr->isVOP3P()) {
VOP3P_instruction& vop3p = instr->vop3p();
neg[2] = vop3p.neg_lo[add_op_idx];
abs[2] = vop3p.neg_hi[add_op_idx];
opsel_lo |= vop3p.opsel_lo & (1 << add_op_idx) ? 0x4 : 0x0;
opsel_hi |= vop3p.opsel_hi & (1 << add_op_idx) ? 0x4 : 0x0;
clamp = vop3p.clamp;
/* abs of the multiplication result */
if (vop3p.neg_hi[3 - add_op_idx]) {
neg[0] = false;
neg[1] = false;
abs[0] = true;
abs[1] = true;
}
/* neg of the multiplication result */
neg[1] = neg[1] ^ vop3p.neg_lo[3 - add_op_idx];
}
if (instr->opcode == aco_opcode::v_sub_f32 || instr->opcode == aco_opcode::v_sub_f16)
neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
else if (instr->opcode == aco_opcode::v_subrev_f32 ||
instr->opcode == aco_opcode::v_subrev_f16)
neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
assert(emit_fma == (ctx.program->chip_class >= GFX10_3));
mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : aco_opcode::v_mad_legacy_f32;
} else if (mad16) {
mad_op = emit_fma ? (ctx.program->chip_class == GFX8 ? aco_opcode::v_fma_legacy_f16
: aco_opcode::v_fma_f16)
: (ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_f16
: aco_opcode::v_mad_f16);
} else if (mad64) {
mad_op = aco_opcode::v_fma_f64;
}
aco_ptr<Instruction> add_instr = std::move(instr);
if (add_instr->isVOP3P() || mul_instr->isVOP3P()) {
assert(!omod);
aco_ptr<VOP3_instruction> mad{
create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
for (unsigned i = 0; i < 3; i++) {
mad->operands[i] = op[i];
mad->neg[i] = neg[i];
mad->abs[i] = abs[i];
aco_opcode mad_op = add_instr->definitions[0].bytes() == 2 ? aco_opcode::v_fma_mixlo_f16
: aco_opcode::v_fma_mix_f32;
aco_ptr<VOP3P_instruction> mad{
create_instruction<VOP3P_instruction>(mad_op, Format::VOP3P, 3, 1)};
for (unsigned i = 0; i < 3; i++) {
mad->operands[i] = op[i];
mad->neg_lo[i] = neg[i];
mad->neg_hi[i] = abs[i];
}
mad->clamp = clamp;
mad->opsel_lo = opsel_lo;
mad->opsel_hi = opsel_hi;
instr = std::move(mad);
} else {
aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
assert(emit_fma == (ctx.program->chip_class >= GFX10_3));
mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : aco_opcode::v_mad_legacy_f32;
} else if (mad16) {
mad_op = emit_fma ? (ctx.program->chip_class == GFX8 ? aco_opcode::v_fma_legacy_f16
: aco_opcode::v_fma_f16)
: (ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_f16
: aco_opcode::v_mad_f16);
} else if (mad64) {
mad_op = aco_opcode::v_fma_f64;
}
aco_ptr<VOP3_instruction> mad{
create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
for (unsigned i = 0; i < 3; i++) {
mad->operands[i] = op[i];
mad->neg[i] = neg[i];
mad->abs[i] = abs[i];
}
mad->omod = omod;
mad->clamp = clamp;
instr = std::move(mad);
}
mad->omod = omod;
mad->clamp = clamp;
mad->definitions[0] = instr->definitions[0];
instr->definitions[0] = add_instr->definitions[0];
/* mark this ssa_def to be re-checked for profitability and literals */
ctx.mad_infos.emplace_back(std::move(instr), mul_instr->definitions[0].tempId());
ctx.info[mad->definitions[0].tempId()].set_mad(mad.get(), ctx.mad_infos.size() - 1);
instr = std::move(mad);
ctx.mad_infos.emplace_back(std::move(add_instr), mul_instr->definitions[0].tempId());
ctx.info[instr->definitions[0].tempId()].set_mad(instr.get(), ctx.mad_infos.size() - 1);
return;
}
}
@@ -4084,7 +4152,8 @@ select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
mad_info = NULL;
}
/* check literals */
else if (!instr->usesModifiers() && instr->opcode != aco_opcode::v_fma_f64 &&
else if (!instr->usesModifiers() && !instr->isVOP3P() &&
instr->opcode != aco_opcode::v_fma_f64 &&
instr->opcode != aco_opcode::v_mad_legacy_f32 &&
instr->opcode != aco_opcode::v_fma_legacy_f32) {
/* FMA can only take literals on GFX10+ */