diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp
index a775fd29edb..6e9ca7a69f5 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -122,12 +122,13 @@ enum Label {
    label_insert = 1ull << 34,
    label_dpp16 = 1ull << 35,
    label_dpp8 = 1ull << 36,
+   label_f2f32 = 1ull << 37,
 };
 
 static constexpr uint64_t instr_usedef_labels =
    label_vec | label_mul | label_mad | label_add_sub | label_vop3p | label_bitwise |
    label_uniform_bitwise | label_minmax | label_vopc | label_usedef | label_extract | label_dpp16 |
-   label_dpp8;
+   label_dpp8 | label_f2f32;
 
 static constexpr uint64_t instr_mod_labels =
    label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert;
@@ -441,6 +442,15 @@ struct ssa_info {
 
    bool is_canonicalized() { return label & label_canonicalized; }
 
+   /* Adds label_f2f32 and remembers the conversion instruction that defines this value. */
+   void set_f2f32(Instruction* cvt)
+   {
+      add_label(label_f2f32);
+      instr = cvt;
+   }
+
+   bool is_f2f32() { return label & label_f2f32; }
+
    void set_extract(Instruction* extract)
    {
       add_label(label_extract);
@@ -859,6 +869,9 @@ get_operand_size(aco_ptr<Instruction>& instr, unsigned index)
    else if (instr->opcode == aco_opcode::v_mad_u64_u32 ||
             instr->opcode == aco_opcode::v_mad_i64_i32)
       return index == 2 ? 64 : 32;
+   else if (instr->opcode == aco_opcode::v_fma_mix_f32 ||
+            instr->opcode == aco_opcode::v_fma_mixlo_f16)
+      return instr->vop3p().opsel_hi & (1u << index) ? 16 : 32;
    else if (instr->isVALU() || instr->isSALU())
       return instr_info.operand_size[(int)instr->opcode];
    else
@@ -1075,9 +1088,11 @@ apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info&
       return;
    }
 
-   /* output modifier and label_vopc seem to be the only one worth keeping at the moment */
+   /* Output modifier, label_vopc and label_f2f32 seem to be the only ones worth keeping at the
+    * moment
+    */
    for (Definition& def : instr->definitions)
-      ctx.info[def.tempId()].label &= (label_vopc | instr_mod_labels);
+      ctx.info[def.tempId()].label &= (label_vopc | label_f2f32 | instr_mod_labels);
 }
 
 void
@@ -1875,6 +1890,11 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
       break;
    }
+   case aco_opcode::v_cvt_f32_f16: {
+      if (instr->operands[0].isTemp())
+         ctx.info[instr->definitions[0].tempId()].set_f2f32(instr.get());
+      break;
+   }
    default: break;
    }
 
@@ -3414,6 +3434,143 @@ combine_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    }
 }
 
+/* Returns whether instr can be (or already is) a v_fma_mix instruction. Requires GFX9+ and, for
+ * v_add/sub/subrev/mul/fma_f32, a VOP2 or VOP3 encoding without omod or opsel bit 3. A precise
+ * v_fma_f32 is rejected unless the device reports fused_mad_mix. */
+bool
+can_use_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
+{
+   if (ctx.program->chip_class < GFX9)
+      return false;
+
+   switch (instr->opcode) {
+   case aco_opcode::v_add_f32:
+   case aco_opcode::v_sub_f32:
+   case aco_opcode::v_subrev_f32:
+   case aco_opcode::v_mul_f32:
+   case aco_opcode::v_fma_f32: break;
+   case aco_opcode::v_fma_mix_f32:
+   case aco_opcode::v_fma_mixlo_f16: return true;
+   default: return false;
+   }
+
+   if (instr->opcode == aco_opcode::v_fma_f32 && !ctx.program->dev.fused_mad_mix &&
+       instr->definitions[0].isPrecise())
+      return false;
+
+   if (instr->isVOP3())
+      return !instr->vop3().omod && !(instr->vop3().opsel & 0x8);
+
+   return instr->format == Format::VOP2;
+}
+
+/* Rewrites a 32-bit add/sub/subrev/mul/fma (see can_use_mad_mix) into v_fma_mix_f32 with every
+ * source still 32-bit (opsel_hi = 0). add/sub/subrev become 1.0 * a + b with the subtracted
+ * source's neg flipped, mul becomes a * b + (-0). neg/abs modifiers move to neg_lo/neg_hi. */
+void
+to_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
+{
+   bool is_add = instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
+
+   aco_ptr<VOP3P_instruction> vop3p{
+      create_instruction<VOP3P_instruction>(aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1)};
+
+   vop3p->opsel_lo = instr->isVOP3() ? (instr->vop3().opsel & 0x7) << is_add : 0x0;
+   vop3p->opsel_hi = 0x0;
+   for (unsigned i = 0; i < instr->operands.size(); i++) {
+      vop3p->operands[is_add + i] = instr->operands[i];
+      vop3p->neg_lo[is_add + i] = instr->isVOP3() && instr->vop3().neg[i];
+      vop3p->neg_lo[is_add + i] |= instr->isSDWA() && instr->sdwa().neg[i];
+      vop3p->neg_hi[is_add + i] = instr->isVOP3() && instr->vop3().abs[i];
+      vop3p->neg_hi[is_add + i] |= instr->isSDWA() && instr->sdwa().abs[i];
+      vop3p->opsel_lo |= (instr->isSDWA() && instr->sdwa().sel[i].offset()) << (is_add + i);
+   }
+   if (instr->opcode == aco_opcode::v_mul_f32) {
+      vop3p->opsel_hi &= 0x3;
+      vop3p->operands[2] = Operand::zero();
+      vop3p->neg_lo[2] = true;
+   } else if (is_add) {
+      vop3p->opsel_hi &= 0x6;
+      vop3p->operands[0] = Operand::c32(0x3f800000);
+      if (instr->opcode == aco_opcode::v_sub_f32)
+         vop3p->neg_lo[2] ^= true;
+      else if (instr->opcode == aco_opcode::v_subrev_f32)
+         vop3p->neg_lo[1] ^= true;
+   }
+   vop3p->definitions[0] = instr->definitions[0];
+   vop3p->clamp = instr->isVOP3() && instr->vop3().clamp;
+   instr = std::move(vop3p);
+
+   /* Keep only the clamp/mul labels; a kept mul label must point at the new instruction. */
+   ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_mul;
+   if (ctx.info[instr->definitions[0].tempId()].label & label_mul)
+      ctx.info[instr->definitions[0].tempId()].instr = instr.get();
+}
+
+/* Folds v_cvt_f32_f16 results labelled f2f32 into the 32-bit sources of instr: the operand is
+ * replaced by the conversion's own source and its opsel_hi bit is set so it is read as 16-bit.
+ * instr is rewritten with to_mad_mix() first when it is not already VOP3P. */
+void
+combine_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
+{
+   if (!can_use_mad_mix(ctx, instr))
+      return;
+
+   for (unsigned i = 0; i < instr->operands.size(); i++) {
+      if (!instr->operands[i].isTemp())
+         continue;
+      Temp tmp = instr->operands[i].getTemp();
+      if (!ctx.info[tmp.id()].is_f2f32())
+         continue;
+
+      Instruction* conv = ctx.info[tmp.id()].instr;
+      if (conv->isSDWA() && (conv->sdwa().dst_sel.size() != 4 || conv->sdwa().sel[0].size() != 2 ||
+                             conv->sdwa().clamp || conv->sdwa().omod)) {
+         continue;
+      } else if (conv->isVOP3() && (conv->vop3().clamp || conv->vop3().omod)) {
+         continue;
+      } else if (conv->isDPP()) {
+         continue;
+      }
+
+      if (get_operand_size(instr, i) != 32)
+         continue;
+
+      /* Conversion to VOP3P will add inline constant operands, but that shouldn't affect
+       * check_vop3_operands(). */
+      Operand op[3];
+      for (unsigned j = 0; j < instr->operands.size(); j++)
+         op[j] = instr->operands[j];
+      op[i] = conv->operands[0];
+      if (!check_vop3_operands(ctx, instr->operands.size(), op))
+         continue;
+
+      if (!instr->isVOP3P()) {
+         bool is_add =
+            instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
+         to_mad_mix(ctx, instr);
+         i += is_add;
+      }
+
+      /* Drop our use of the conversion; if it stays alive, its source gains a user (instr). */
+      if (--ctx.uses[tmp.id()])
+         ctx.uses[conv->operands[0].tempId()]++;
+      instr->operands[i].setTemp(conv->operands[0].getTemp());
+      if (conv->definitions[0].isPrecise())
+         instr->definitions[0].setPrecise(true);
+      instr->vop3p().opsel_hi ^= 1u << i;
+      if (conv->isSDWA() && conv->sdwa().sel[0].offset() == 2)
+         instr->vop3p().opsel_lo |= 1u << i;
+      bool neg = (conv->isVOP3() && conv->vop3().neg[0]) || (conv->isSDWA() && conv->sdwa().neg[0]);
+      bool abs = (conv->isVOP3() && conv->vop3().abs[0]) || (conv->isSDWA() && conv->sdwa().abs[0]);
+      /* neg_hi acts as abs here: an abs already on instr makes the conversion's neg/abs moot. */
+      if (!instr->vop3p().neg_hi[i]) {
+         instr->vop3p().neg_lo[i] ^= neg;
+         instr->vop3p().neg_hi[i] = abs;
+      }
+   }
+}
+
 // TODO: we could possibly move the whole label_instruction pass to combine_instruction:
 // this would mean that we'd have to fix the instruction uses while value propagation
 
@@ -3453,12 +3610,14 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 
       if (can_apply_sgprs(ctx, instr))
          apply_sgprs(ctx, instr);
+      combine_mad_mix(ctx, instr);
       while (apply_omod_clamp(ctx, instr))
          ;
       apply_insert(ctx, instr);
    }
 
-   if (instr->isVOP3P())
+   if (instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
+       instr->opcode != aco_opcode::v_fma_mixlo_f16)
       return combine_vop3p(ctx, instr);
 
    if (ctx.info[instr->definitions[0].tempId()].is_vcc_hint()) {
@@ -3504,7 +3663,7 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
          return;
       if (mul_instr->isVOP3() && mul_instr->vop3().clamp)
          return;
-      if (mul_instr->isSDWA() || mul_instr->isDPP())
+      if (mul_instr->isSDWA() || mul_instr->isDPP() || mul_instr->isVOP3P())
          return;
       if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32 &&
           ctx.fp_mode.preserve_signed_zero_inf_nan32)
@@ -3562,6 +3721,8 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
             /* no clamp/omod allowed between mul and add */
             if (info.instr->isVOP3() && (info.instr->vop3().clamp || info.instr->vop3().omod))
                continue;
+            if (info.instr->isVOP3P())
+               continue;
 
             if (get_operand_size(instr, i) != info.instr->definitions[0].bytes() * 8)
                continue;