aco: use v_fma_mix to combine mul/add/fma input conversions
fossil-db (Sienna Cichlid):
Totals from 11558 (8.57% of 134913) affected shaders:
VGPRs: 829392 -> 825200 (-0.51%); split: -0.52%, +0.02%
SpillSGPRs: 7845 -> 8399 (+7.06%)
CodeSize: 101822704 -> 101677172 (-0.14%); split: -0.25%, +0.11%
MaxWaves: 172216 -> 173182 (+0.56%); split: +0.59%, -0.03%
Instrs: 19061343 -> 18883450 (-0.93%); split: -0.93%, +0.00%
Latency: 256011590 -> 255177378 (-0.33%); split: -0.39%, +0.06%
InvThroughput: 46104438 -> 45604059 (-1.09%); split: -1.12%, +0.04%
VClause: 352211 -> 351948 (-0.07%); split: -0.21%, +0.13%
SClause: 676506 -> 676961 (+0.07%); split: -0.04%, +0.11%
Copies: 1246571 -> 1237745 (-0.71%); split: -0.97%, +0.26%
Branches: 626229 -> 626241 (+0.00%); split: -0.02%, +0.03%
PreSGPRs: 882176 -> 888853 (+0.76%); split: -0.00%, +0.76%
PreVGPRs: 796705 -> 792304 (-0.55%); split: -0.56%, +0.00%

fossil-db (Navi):
Totals from 11558 (8.57% of 134913) affected shaders:
VGPRs: 803900 -> 798660 (-0.65%); split: -0.73%, +0.08%
SpillSGPRs: 7894 -> 8492 (+7.58%); split: -0.10%, +7.68%
CodeSize: 96892596 -> 97134716 (+0.25%); split: -0.05%, +0.29%
MaxWaves: 181454 -> 183014 (+0.86%); split: +0.94%, -0.08%
Instrs: 18186813 -> 18093994 (-0.51%); split: -0.56%, +0.05%
Latency: 253385909 -> 253325528 (-0.02%); split: -0.15%, +0.12%
InvThroughput: 43315355 -> 42805541 (-1.18%); split: -1.33%, +0.15%
VClause: 338755 -> 338535 (-0.06%); split: -0.16%, +0.10%
SClause: 656561 -> 656829 (+0.04%); split: -0.07%, +0.11%
Copies: 1162235 -> 1153558 (-0.75%); split: -1.07%, +0.32%
Branches: 588536 -> 588542 (+0.00%); split: -0.03%, +0.03%
PreSGPRs: 854849 -> 861640 (+0.79%); split: -0.00%, +0.80%
PreVGPRs: 783401 -> 779031 (-0.56%); split: -0.56%, +0.00%

fossil-db (Vega):
Totals from 11516 (8.53% of 135048) affected shaders:
SGPRs: 1072128 -> 1076288 (+0.39%); split: -0.01%, +0.40%
VGPRs: 821312 -> 818124 (-0.39%); split: -0.43%, +0.04%
SpillSGPRs: 11952 -> 12677 (+6.07%)
CodeSize: 96378496 -> 96707596 (+0.34%); split: -0.04%, +0.38%
MaxWaves: 42614 -> 42883 (+0.63%); split: +0.68%, -0.04%
Instrs: 18672844 -> 18600274 (-0.39%); split: -0.44%, +0.05%
Latency: 296658786 -> 296338296 (-0.11%); split: -0.21%, +0.10%
InvThroughput: 111665547 -> 111283559 (-0.34%); split: -0.40%, +0.06%
VClause: 343001 -> 342826 (-0.05%); split: -0.14%, +0.09%
SClause: 646684 -> 646657 (-0.00%); split: -0.05%, +0.04%
Copies: 1715316 -> 1712895 (-0.14%); split: -0.53%, +0.39%
PreSGPRs: 850737 -> 856543 (+0.68%); split: -0.04%, +0.72%
PreVGPRs: 775293 -> 772215 (-0.40%); split: -0.41%, +0.02%

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14769>
This commit is contained in:
@@ -122,12 +122,13 @@ enum Label {
|
||||
label_insert = 1ull << 34,
|
||||
label_dpp16 = 1ull << 35,
|
||||
label_dpp8 = 1ull << 36,
|
||||
label_f2f32 = 1ull << 37,
|
||||
};
|
||||
|
||||
static constexpr uint64_t instr_usedef_labels =
|
||||
label_vec | label_mul | label_mad | label_add_sub | label_vop3p | label_bitwise |
|
||||
label_uniform_bitwise | label_minmax | label_vopc | label_usedef | label_extract | label_dpp16 |
|
||||
label_dpp8;
|
||||
label_dpp8 | label_f2f32;
|
||||
static constexpr uint64_t instr_mod_labels =
|
||||
label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert;
|
||||
|
||||
@@ -441,6 +442,14 @@ struct ssa_info {
|
||||
|
||||
bool is_canonicalized() { return label & label_canonicalized; }
|
||||
|
||||
/* Adds label_f2f32 to this value and records `cvt` as the instruction associated with it.
 * In this file it is called from label_instruction() for v_cvt_f32_f16 results whose source
 * is a temporary. */
void set_f2f32(Instruction* cvt)
|
||||
{
|
||||
add_label(label_f2f32);
|
||||
/* `instr` is what users of is_f2f32() read back to inspect the conversion. */
instr = cvt;
|
||||
}
|
||||
|
||||
/* True when label_f2f32 is set on this value (see set_f2f32()). */
bool is_f2f32() { return label & label_f2f32; }
|
||||
|
||||
void set_extract(Instruction* extract)
|
||||
{
|
||||
add_label(label_extract);
|
||||
@@ -859,6 +868,9 @@ get_operand_size(aco_ptr<Instruction>& instr, unsigned index)
|
||||
else if (instr->opcode == aco_opcode::v_mad_u64_u32 ||
|
||||
instr->opcode == aco_opcode::v_mad_i64_i32)
|
||||
return index == 2 ? 64 : 32;
|
||||
else if (instr->opcode == aco_opcode::v_fma_mix_f32 ||
|
||||
instr->opcode == aco_opcode::v_fma_mixlo_f16)
|
||||
return instr->vop3p().opsel_hi & (1u << index) ? 16 : 32;
|
||||
else if (instr->isVALU() || instr->isSALU())
|
||||
return instr_info.operand_size[(int)instr->opcode];
|
||||
else
|
||||
@@ -1075,9 +1087,11 @@ apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info&
|
||||
return;
|
||||
}
|
||||
|
||||
/* output modifier and label_vopc seem to be the only one worth keeping at the moment */
|
||||
/* Output modifier, label_vopc and label_f2f32 seem to be the only one worth keeping at the
|
||||
* moment
|
||||
*/
|
||||
for (Definition& def : instr->definitions)
|
||||
ctx.info[def.tempId()].label &= (label_vopc | instr_mod_labels);
|
||||
ctx.info[def.tempId()].label &= (label_vopc | label_f2f32 | instr_mod_labels);
|
||||
}
|
||||
|
||||
void
|
||||
@@ -1875,6 +1889,11 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
|
||||
break;
|
||||
}
|
||||
case aco_opcode::v_cvt_f32_f16: {
|
||||
if (instr->operands[0].isTemp())
|
||||
ctx.info[instr->definitions[0].tempId()].set_f2f32(instr.get());
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
|
||||
@@ -3414,6 +3433,131 @@ combine_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
}
|
||||
}
|
||||
|
||||
/* Returns whether `instr` either already is a v_fma_mix_f32/v_fma_mixlo_f16, or is an f32
 * add/sub/subrev/mul/fma in an encoding this function accepts for rewriting into one.
 * Always false below GFX9. */
bool
|
||||
can_use_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
{
|
||||
if (ctx.program->chip_class < GFX9)
|
||||
return false;
|
||||
|
||||
switch (instr->opcode) {
|
||||
/* Candidates for conversion: fall through to the encoding checks below. */
case aco_opcode::v_add_f32:
|
||||
case aco_opcode::v_sub_f32:
|
||||
case aco_opcode::v_subrev_f32:
|
||||
case aco_opcode::v_mul_f32:
|
||||
case aco_opcode::v_fma_f32: break;
|
||||
/* Already a mix instruction: nothing to convert. */
case aco_opcode::v_fma_mix_f32:
|
||||
case aco_opcode::v_fma_mixlo_f16: return true;
|
||||
default: return false;
|
||||
}
|
||||
|
||||
/* A precise v_fma_f32 is rejected when the device does not report fused_mad_mix.
 * NOTE(review): presumably because the mix instruction would then not be a fused
 * multiply-add and could round differently - confirm against the device info. */
if (instr->opcode == aco_opcode::v_fma_f32 && !ctx.program->dev.fused_mad_mix &&
|
||||
instr->definitions[0].isPrecise())
|
||||
return false;
|
||||
|
||||
/* VOP3 is accepted only without omod and with opsel bit 3 (0x8) clear.
 * NOTE(review): bit 3 looks like the destination opsel bit - confirm. */
if (instr->isVOP3())
|
||||
return !instr->vop3().omod && !(instr->vop3().opsel & 0x8);
|
||||
|
||||
/* Anything else must be a plain VOP2 encoding (an exact format match). */
return instr->format == Format::VOP2;
|
||||
}
|
||||
|
||||
/* Replaces `instr` with a newly created 3-operand v_fma_mix_f32 (VOP3P) that carries over
 * its operands, neg/abs modifiers, opsel bits, clamp and definition.
 *   - v_mul_f32:        a * b + (-0)
 *   - add/sub/subrev:   1.0 * a + b, with the sub/subrev negation folded into neg_lo
 *   - v_fma_f32:        operands copied as-is
 * NOTE(review): callers are expected to have checked can_use_mad_mix() first; this function
 * does not validate the opcode itself. */
void
|
||||
to_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
{
|
||||
/* Everything except mul and fma is treated as add-like: its operands are shifted up by one
 * slot so that slot 0 can hold the constant multiplier. */
bool is_add = instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
|
||||
|
||||
aco_ptr<VOP3P_instruction> vop3p{
|
||||
create_instruction<VOP3P_instruction>(aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1)};
|
||||
|
||||
/* Carry over the three source opsel bits, shifted along with the operands. */
vop3p->opsel_lo = instr->isVOP3() ? (instr->vop3().opsel & 0x7) << is_add : 0x0;
|
||||
/* All opsel_hi bits start cleared; get_operand_size() reports 32 bits for a mix operand
 * whose opsel_hi bit is clear. */
vop3p->opsel_hi = 0x0;
|
||||
for (unsigned i = 0; i < instr->operands.size(); i++) {
|
||||
vop3p->operands[is_add + i] = instr->operands[i];
|
||||
/* The source's neg modifier is stored in neg_lo and its abs modifier in neg_hi. */
vop3p->neg_lo[is_add + i] = instr->isVOP3() && instr->vop3().neg[i];
|
||||
vop3p->neg_lo[is_add + i] |= instr->isSDWA() && instr->sdwa().neg[i];
|
||||
vop3p->neg_hi[is_add + i] = instr->isVOP3() && instr->vop3().abs[i];
|
||||
vop3p->neg_hi[is_add + i] |= instr->isSDWA() && instr->sdwa().abs[i];
|
||||
/* A non-zero SDWA source select offset sets the operand's opsel_lo bit. */
vop3p->opsel_lo |= (instr->isSDWA() && instr->sdwa().sel[i].offset()) << (is_add + i);
|
||||
}
|
||||
if (instr->opcode == aco_opcode::v_mul_f32) {
|
||||
/* NOTE(review): opsel_hi was set to 0x0 above and is not modified in between, so this
 * mask has no effect. */
vop3p->opsel_hi &= 0x3;
|
||||
/* Addend is zero with neg_lo set, i.e. -0.
 * NOTE(review): presumably chosen so the addition keeps the sign of a zero product. */
vop3p->operands[2] = Operand::zero();
|
||||
vop3p->neg_lo[2] = true;
|
||||
} else if (is_add) {
|
||||
/* NOTE(review): as above, masking an all-zero opsel_hi has no effect. */
vop3p->opsel_hi &= 0x6;
|
||||
/* 0x3f800000 is the bit pattern of 1.0f: the multiplier for the add-like forms. */
vop3p->operands[0] = Operand::c32(0x3f800000);
|
||||
/* sub negates the original second operand (now slot 2), subrev the original first
 * operand (now slot 1); ^= so an existing neg modifier is flipped, not overwritten. */
if (instr->opcode == aco_opcode::v_sub_f32)
|
||||
vop3p->neg_lo[2] ^= true;
|
||||
else if (instr->opcode == aco_opcode::v_subrev_f32)
|
||||
vop3p->neg_lo[1] ^= true;
|
||||
}
|
||||
vop3p->definitions[0] = instr->definitions[0];
|
||||
vop3p->clamp = instr->isVOP3() && instr->vop3().clamp;
|
||||
/* From here on `instr` refers to the new mix instruction. */
instr = std::move(vop3p);
|
||||
|
||||
/* Keep only label_clamp and label_mul on the result; if label_mul survives, point the
 * stored instruction at the replacement since the original one was replaced above. */
ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_mul;
|
||||
if (ctx.info[instr->definitions[0].tempId()].label & label_mul)
|
||||
ctx.info[instr->definitions[0].tempId()].instr = instr.get();
|
||||
}
|
||||
|
||||
/* For every temporary operand of `instr` that is labelled f2f32, tries to fold the
 * conversion into `instr`: the instruction is turned into a v_fma_mix (if it is not one
 * already) and the operand is replaced by the conversion's own source, marked via opsel_hi.
 * Does nothing if can_use_mad_mix() rejects `instr`. */
void
|
||||
combine_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
{
|
||||
if (!can_use_mad_mix(ctx, instr))
|
||||
return;
|
||||
|
||||
for (unsigned i = 0; i < instr->operands.size(); i++) {
|
||||
if (!instr->operands[i].isTemp())
|
||||
continue;
|
||||
Temp tmp = instr->operands[i].getTemp();
|
||||
if (!ctx.info[tmp.id()].is_f2f32())
|
||||
continue;
|
||||
|
||||
/* The conversion instruction recorded by set_f2f32(). */
Instruction* conv = ctx.info[tmp.id()].instr;
|
||||
/* Skip conversions that cannot be folded: SDWA ones unless they write a 4-byte
 * destination from a 2-byte source selection without clamp/omod, VOP3 ones with
 * clamp/omod, and any DPP conversion. */
if (conv->isSDWA() && (conv->sdwa().dst_sel.size() != 4 || conv->sdwa().sel[0].size() != 2 ||
|
||||
conv->sdwa().clamp || conv->sdwa().omod)) {
|
||||
continue;
|
||||
} else if (conv->isVOP3() && (conv->vop3().clamp || conv->vop3().omod)) {
|
||||
continue;
|
||||
} else if (conv->isDPP()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
/* Only operands currently read as 32-bit are candidates; for a mix instruction this
 * means the operand's opsel_hi bit is still clear (see get_operand_size()). */
if (get_operand_size(instr, i) != 32)
|
||||
continue;
|
||||
|
||||
/* Conversion to VOP3P will add inline constant operands, but that shouldn't affect
|
||||
* check_vop3_operands(). */
|
||||
/* Validate the operand list as it would look with the conversion's source substituted. */
Operand op[3];
|
||||
for (unsigned j = 0; j < instr->operands.size(); j++)
|
||||
op[j] = instr->operands[j];
|
||||
op[i] = conv->operands[0];
|
||||
if (!check_vop3_operands(ctx, instr->operands.size(), op))
|
||||
continue;
|
||||
|
||||
if (!instr->isVOP3P()) {
|
||||
/* is_add must be computed before to_mad_mix() replaces the instruction and its opcode. */
bool is_add =
|
||||
instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
|
||||
to_mad_mix(ctx, instr);
|
||||
/* to_mad_mix() moves the operands of add-like instructions up by one slot. */
i += is_add;
|
||||
}
|
||||
|
||||
/* This instruction stops using the conversion result. If the result still has other
 * users, the conversion's source gains an additional use from this instruction.
 * NOTE(review): when the count reaches zero the source's count is left untouched,
 * presumably because the now-unused conversion's own use is taken over - confirm. */
if (--ctx.uses[tmp.id()])
|
||||
ctx.uses[conv->operands[0].tempId()]++;
|
||||
instr->operands[i].setTemp(conv->operands[0].getTemp());
|
||||
/* Preciseness of the conversion is propagated to the combined instruction. */
if (conv->definitions[0].isPrecise())
|
||||
instr->definitions[0].setPrecise(true);
|
||||
/* Toggle opsel_hi for this operand; get_operand_size() then reports it as 16-bit. */
instr->vop3p().opsel_hi ^= 1u << i;
|
||||
/* An SDWA conversion reading at byte offset 2 sets the operand's opsel_lo bit.
 * NOTE(review): presumably this selects the upper 16 bits of the register - confirm. */
if (conv->isSDWA() && conv->sdwa().sel[0].offset() == 2)
|
||||
instr->vop3p().opsel_lo |= 1u << i;
|
||||
bool neg = (conv->isVOP3() && conv->vop3().neg[0]) || (conv->isSDWA() && conv->sdwa().neg[0]);
|
||||
bool abs = (conv->isVOP3() && conv->vop3().abs[0]) || (conv->isSDWA() && conv->sdwa().abs[0]);
|
||||
/* If the operand already has neg_hi set (to_mad_mix() stores abs there), the conversion's
 * own neg/abs are not merged; otherwise merge them: neg toggles neg_lo, abs goes to neg_hi. */
if (!instr->vop3p().neg_hi[i]) {
|
||||
instr->vop3p().neg_lo[i] ^= neg;
|
||||
instr->vop3p().neg_hi[i] = abs;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: we could possibly move the whole label_instruction pass to combine_instruction:
|
||||
// this would mean that we'd have to fix the instruction uses while value propagation
|
||||
|
||||
@@ -3453,12 +3597,14 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
|
||||
if (can_apply_sgprs(ctx, instr))
|
||||
apply_sgprs(ctx, instr);
|
||||
combine_mad_mix(ctx, instr);
|
||||
while (apply_omod_clamp(ctx, instr))
|
||||
;
|
||||
apply_insert(ctx, instr);
|
||||
}
|
||||
|
||||
if (instr->isVOP3P())
|
||||
if (instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
|
||||
instr->opcode != aco_opcode::v_fma_mixlo_f16)
|
||||
return combine_vop3p(ctx, instr);
|
||||
|
||||
if (ctx.info[instr->definitions[0].tempId()].is_vcc_hint()) {
|
||||
@@ -3504,7 +3650,7 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
return;
|
||||
if (mul_instr->isVOP3() && mul_instr->vop3().clamp)
|
||||
return;
|
||||
if (mul_instr->isSDWA() || mul_instr->isDPP())
|
||||
if (mul_instr->isSDWA() || mul_instr->isDPP() || mul_instr->isVOP3P())
|
||||
return;
|
||||
if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32 &&
|
||||
ctx.fp_mode.preserve_signed_zero_inf_nan32)
|
||||
@@ -3562,6 +3708,8 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
/* no clamp/omod allowed between mul and add */
|
||||
if (info.instr->isVOP3() && (info.instr->vop3().clamp || info.instr->vop3().omod))
|
||||
continue;
|
||||
if (info.instr->isVOP3P())
|
||||
continue;
|
||||
|
||||
if (get_operand_size(instr, i) != info.instr->definitions[0].bytes() * 8)
|
||||
continue;
|
||||
|
Reference in New Issue
Block a user