From 54292e99c7844500314bfd623469c65adef954c5 Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Wed, 12 Aug 2020 14:23:56 +0100 Subject: [PATCH] aco: optimize 32-bit extracts and inserts using SDWA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Still need to use dst_u=preserve field to optimize packs fossil-db (Sienna Cichlid): Totals from 15974 (10.66% of 149839) affected shaders: VGPRs: 1009064 -> 1008968 (-0.01%); split: -0.03%, +0.02% SpillSGPRs: 7959 -> 7964 (+0.06%) CodeSize: 101716436 -> 101159568 (-0.55%); split: -0.55%, +0.01% MaxWaves: 284464 -> 284490 (+0.01%); split: +0.02%, -0.01% Instrs: 19334216 -> 19224241 (-0.57%); split: -0.57%, +0.00% Latency: 375465295 -> 375230478 (-0.06%); split: -0.14%, +0.08% InvThroughput: 79006105 -> 78860705 (-0.18%); split: -0.25%, +0.07% fossil-db (Polaris): Totals from 11369 (7.51% of 151365) affected shaders: SGPRs: 787920 -> 787680 (-0.03%); split: -0.04%, +0.01% VGPRs: 681056 -> 681040 (-0.00%); split: -0.01%, +0.00% CodeSize: 68127288 -> 67664120 (-0.68%); split: -0.69%, +0.01% MaxWaves: 54370 -> 54371 (+0.00%) Instrs: 13294638 -> 13214109 (-0.61%); split: -0.62%, +0.01% Latency: 373515759 -> 373214571 (-0.08%); split: -0.11%, +0.03% InvThroughput: 166529524 -> 166275291 (-0.15%); split: -0.20%, +0.05% Signed-off-by: Rhys Perry Reviewed-by: Timur Kristóf Part-of: --- src/amd/compiler/aco_ir.cpp | 8 +- src/amd/compiler/aco_ir.h | 2 +- src/amd/compiler/aco_optimizer.cpp | 276 ++++++++++++++++++- src/amd/compiler/aco_register_allocation.cpp | 8 +- 4 files changed, 271 insertions(+), 23 deletions(-) diff --git a/src/amd/compiler/aco_ir.cpp b/src/amd/compiler/aco_ir.cpp index 5f6eb5c177d..da53b2b2914 100644 --- a/src/amd/compiler/aco_ir.cpp +++ b/src/amd/compiler/aco_ir.cpp @@ -196,7 +196,7 @@ memory_sync_info get_sync_info(const Instruction* instr) } } -bool can_use_SDWA(chip_class chip, const aco_ptr& instr) +bool can_use_SDWA(chip_class chip, const aco_ptr& instr, bool pre_ra) { if (!instr->isVALU()) return false; @@ -217,7 +217,7 @@ bool can_use_SDWA(chip_class chip, const aco_ptr& instr) return false; //TODO: return true if we know we will use vcc - if (instr->definitions.size() >= 2) + if (!pre_ra && instr->definitions.size() >= 2) return false; for (unsigned i = 1; i < instr->operands.size(); i++) { @@ -251,9 +251,9 @@ bool can_use_SDWA(chip_class chip, const aco_ptr& instr) return false; //TODO: return true if we know we will use vcc - if (instr->isVOPC()) + if (!pre_ra && instr->isVOPC()) return false; - if (instr->operands.size() >= 3 && !is_mac) + if (!pre_ra && instr->operands.size() >= 3 && !is_mac) return false; return instr->opcode != aco_opcode::v_madmk_f32 && diff --git a/src/amd/compiler/aco_ir.h b/src/amd/compiler/aco_ir.h index c05b6589ac3..824138f1148 100644 --- a/src/amd/compiler/aco_ir.h +++ b/src/amd/compiler/aco_ir.h @@ -1623,7 +1623,7 @@ memory_sync_info get_sync_info(const Instruction* instr); bool is_dead(const std::vector& uses, Instruction *instr); bool can_use_opsel(chip_class chip, aco_opcode op, int idx, bool high); -bool can_use_SDWA(chip_class chip, const aco_ptr& instr); +bool can_use_SDWA(chip_class chip, const aco_ptr& instr, bool pre_ra); /* updates "instr" and returns the old instruction (or NULL if no update was needed) */ aco_ptr convert_to_SDWA(chip_class chip, aco_ptr& instr); bool needs_exec_mask(const Instruction* instr); diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index b1fadd33c31..51cd347561e 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -119,11 +119,14 @@ enum Label { label_usedef = 1 << 30, /* generic label */ label_vop3p = 1ull << 31, /* 1ull to prevent sign extension */ label_canonicalized = 1ull << 32, + label_extract = 1ull << 33, + label_insert = 1ull << 34, }; static constexpr uint64_t instr_usedef_labels = label_vec | label_mul | label_mad | label_add_sub | label_vop3p | - label_bitwise | label_uniform_bitwise | label_minmax | label_vopc | label_usedef; -static constexpr uint64_t instr_mod_labels = label_omod2 | label_omod4 | label_omod5 | label_clamp; + label_bitwise | label_uniform_bitwise | label_minmax | label_vopc | + label_usedef | label_extract; +static constexpr uint64_t instr_mod_labels = label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert; static constexpr uint64_t instr_labels = instr_usedef_labels | instr_mod_labels; static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f | label_uniform_bool | @@ -535,6 +538,27 @@ struct ssa_info { return label & label_canonicalized; } + void set_extract(Instruction *extract) + { + add_label(label_extract); + instr = extract; + } + + bool is_extract() + { + return label & label_extract; + } + + void set_insert(Instruction *insert) + { + add_label(label_insert); + instr = insert; + } + + bool is_insert() + { + return label & label_insert; + } }; struct opt_ctx { @@ -745,6 +769,24 @@ void to_VOP3(opt_ctx& ctx, aco_ptr& instr) * been applied yet or this instruction isn't dead and so they've been ignored */ } +bool is_operand_vgpr(Operand op) +{ + return op.isTemp() && op.getTemp().type() == RegType::vgpr; +} + +void to_SDWA(opt_ctx& ctx, aco_ptr& instr) +{ + aco_ptr tmp = convert_to_SDWA(ctx.program->chip_class, instr); + if (!tmp) + return; + + for (unsigned i = 0; i < instr->definitions.size(); i++) { + ssa_info& info = ctx.info[instr->definitions[i].tempId()]; + if (info.label & instr_labels && info.instr == tmp.get()) + info.instr = instr.get(); + } +} + /* only covers special cases */ bool alu_can_accept_constant(aco_opcode opcode, unsigned operand) { @@ -903,6 +945,121 @@ bool fixed_to_exec(Operand op) return op.isFixed() && op.physReg() == exec; } +int parse_extract(Instruction *instr) +{ + if (instr->opcode == aco_opcode::p_extract) { + bool is_byte = instr->operands[2].constantEquals(8); + unsigned index = instr->operands[1].constantValue(); + unsigned sel = (is_byte ? sdwa_ubyte0 : sdwa_uword0) + index; + if (!instr->operands[3].constantEquals(0)) + sel |= sdwa_sext; + return sel; + } else if (instr->opcode == aco_opcode::p_insert && instr->operands[1].constantEquals(0)) { + return instr->operands[2].constantEquals(8) ? sdwa_ubyte0 : sdwa_uword0; + } else { + return -1; + } +} + +int parse_insert(Instruction *instr) +{ + if (instr->opcode == aco_opcode::p_extract && instr->operands[3].constantEquals(0) && + instr->operands[1].constantEquals(0)) { + return instr->operands[2].constantEquals(8) ? sdwa_ubyte0 : sdwa_uword0; + } else if (instr->opcode == aco_opcode::p_insert) { + bool is_byte = instr->operands[2].constantEquals(8); + unsigned index = instr->operands[1].constantValue(); + unsigned sel = (is_byte ? sdwa_ubyte0 : sdwa_uword0) + index; + return sel; + } else { + return -1; + } +} + +bool can_apply_extract(opt_ctx &ctx, aco_ptr& instr, unsigned idx, ssa_info& info) +{ + if (idx >= 2) + return false; + + Temp tmp = info.instr->operands[0].getTemp(); + unsigned sel = parse_extract(info.instr); + + if (sel == sdwa_udword || sel == sdwa_sdword) { + return true; + } else if (instr->opcode == aco_opcode::v_cvt_f32_u32 && sel <= sdwa_ubyte3) { + return true; + } else if (can_use_SDWA(ctx.program->chip_class, instr, true) && + (tmp.type() == RegType::vgpr || ctx.program->chip_class >= GFX9)) { + if (instr->isSDWA() && (static_cast(instr.get())->sel[idx] & sdwa_asuint) != sdwa_udword) + return false; + return true; + } else if (instr->isVOP3() && (sel & sdwa_isword) && + can_use_opsel(ctx.program->chip_class, instr->opcode, idx, (sel & sdwa_wordnum)) && + !(instr->vop3().opsel & (1 << idx))) { + return true; + } else { + return false; + } +} + +/* Combine an p_extract (or p_insert, in some cases) instruction with instr. + * instr(p_extract(...)) -> instr() + */ +void apply_extract(opt_ctx &ctx, aco_ptr& instr, unsigned idx, ssa_info& info) +{ + Temp tmp = info.instr->operands[0].getTemp(); + unsigned sel = parse_extract(info.instr); + + if (sel == sdwa_udword || sel == sdwa_sdword) { + } else if (instr->opcode == aco_opcode::v_cvt_f32_u32 && sel <= sdwa_ubyte3) { + switch (sel) { + case sdwa_ubyte0: + instr->opcode = aco_opcode::v_cvt_f32_ubyte0; + break; + case sdwa_ubyte1: + instr->opcode = aco_opcode::v_cvt_f32_ubyte1; + break; + case sdwa_ubyte2: + instr->opcode = aco_opcode::v_cvt_f32_ubyte2; + break; + case sdwa_ubyte3: + instr->opcode = aco_opcode::v_cvt_f32_ubyte3; + break; + } + } else if (can_use_SDWA(ctx.program->chip_class, instr, true) && + (tmp.type() == RegType::vgpr || ctx.program->chip_class >= GFX9)) { + to_SDWA(ctx, instr); + static_cast(instr.get())->sel[idx] = sel; + } else if (instr->isVOP3()) { + if (sel & sdwa_wordnum) + instr->vop3().opsel |= 1 << idx; + } + + ctx.info[tmp.id()].label &= ~label_insert; + /* label_vopc seems to be the only one worth keeping at the moment */ + for (Definition& def : instr->definitions) + ctx.info[def.tempId()].label &= label_vopc; +} + +void check_sdwa_extract(opt_ctx &ctx, aco_ptr& instr) +{ + /* only VALU can use SDWA */ + if (!instr->isVALU()) + return; + + for (unsigned i = 0; i < instr->operands.size(); i++) { + Operand op = instr->operands[i]; + if (!op.isTemp()) + continue; + ssa_info& info = ctx.info[op.tempId()]; + if (info.is_extract() && (info.instr->operands[0].getTemp().type() == RegType::vgpr || + op.getTemp().type() == RegType::sgpr)) { + if (!can_apply_extract(ctx, instr, i, info)) + info.label &= ~label_extract; + } + } +} + bool does_fp_op_flush_denorms(opt_ctx &ctx, aco_opcode op) { if (ctx.program->chip_class <= GFX8) { @@ -1200,8 +1357,10 @@ void label_instruction(opt_ctx &ctx, aco_ptr& instr) } /* if this instruction doesn't define anything, return */ - if (instr->definitions.empty()) + if (instr->definitions.empty()) { + check_sdwa_extract(ctx, instr); return; + } if (instr->isVALU() || instr->isVINTRP()) { if (instr_info.can_use_output_modifiers[(int)instr->opcode] || instr->isVINTRP() || @@ -1218,6 +1377,7 @@ void label_instruction(opt_ctx &ctx, aco_ptr& instr) if (instr->isVOPC()) { ctx.info[instr->definitions[0].tempId()].set_vopc(instr.get()); + check_sdwa_extract(ctx, instr); return; } if (instr->isVOP3P()) { @@ -1613,18 +1773,31 @@ void label_instruction(opt_ctx &ctx, aco_ptr& instr) ctx.info[instr->definitions[0].tempId()].set_canonicalized(); break; case aco_opcode::p_extract: { - if (instr->operands[0].isTemp()) - ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get()); + if (instr->definitions[0].bytes() == 4) { + ctx.info[instr->definitions[0].tempId()].set_extract(instr.get()); + if (instr->operands[0].regClass() == v1 && parse_insert(instr.get()) >= 0) + ctx.info[instr->operands[0].tempId()].set_insert(instr.get()); + } break; } case aco_opcode::p_insert: { - if (instr->operands[0].isTemp()) + if (instr->operands[0].bytes() == 4) { + if (instr->operands[0].regClass() == v1) + ctx.info[instr->operands[0].tempId()].set_insert(instr.get()); + if (parse_extract(instr.get()) >= 0) + ctx.info[instr->definitions[0].tempId()].set_extract(instr.get()); ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get()); + } break; } default: break; } + + /* Don't remove label_extract if we can't apply the extract to + * neg/abs instructions because we'll likely combine it into another valu. */ + if (!(ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs))) + check_sdwa_extract(ctx, instr); } ALWAYS_INLINE bool get_cmp_info(aco_opcode op, CmpInfo *info) @@ -1962,7 +2135,7 @@ bool combine_constant_comparison_ordering(opt_ctx &ctx, aco_ptr& in Instruction *nan_test = follow_operand(ctx, instr->operands[0], true); Instruction *cmp = follow_operand(ctx, instr->operands[1], true); - if (!nan_test || !cmp) + if (!nan_test || !cmp || nan_test->isSDWA() || cmp->isSDWA()) return false; if (nan_test->isSDWA() || cmp->isSDWA()) return false; @@ -2288,6 +2461,7 @@ bool combine_add_or_then_and_lshl(opt_ctx& ctx, aco_ptr& instr) bool combine_minmax(opt_ctx& ctx, aco_ptr& instr, aco_opcode opposite, aco_opcode minmax3) { + /* TODO: this can handle SDWA min/max instructions by using opsel */ if (combine_three_valu_op(ctx, instr, instr->opcode, minmax3, "012", 1 | 2)) return true; @@ -2698,6 +2872,8 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr& instr) ssa_info& info = ctx.info[instr->operands[i].tempId()]; if (is_copy_label(ctx, instr, info) && info.temp.type() == RegType::sgpr) operand_mask |= 1u << i; + if (info.is_extract() && info.instr->operands[0].getTemp().type() == RegType::sgpr) + operand_mask |= 1u << i; } unsigned max_sgprs = 1; if (ctx.program->chip_class >= GFX10 && !is_shift64) @@ -2723,20 +2899,26 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr& instr) } operand_mask &= ~(1u << sgpr_idx); + ssa_info& info = ctx.info[sgpr_info_id]; + /* Applying two sgprs require making it VOP3, so don't do it unless it's * definitively beneficial. * TODO: this is too conservative because later the use count could be reduced to 1 */ - if (num_sgprs && ctx.uses[sgpr_info_id] > 1 && + if (!info.is_extract() && num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3() && !instr->isSDWA() && instr->format != Format::VOP3P) break; - Temp sgpr = ctx.info[sgpr_info_id].temp; + Temp sgpr = info.is_extract() ? info.instr->operands[0].getTemp() : info.temp; bool new_sgpr = sgpr.id() != sgpr_ids[0] && sgpr.id() != sgpr_ids[1]; if (new_sgpr && num_sgprs >= max_sgprs) continue; - if (sgpr_idx == 0 || instr->isVOP3() || - instr->isSDWA() || instr->isVOP3P()) { + if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA() || instr->isVOP3P() || info.is_extract()) { + /* can_apply_extract() checks SGPR encoding restrictions */ + if (info.is_extract() && can_apply_extract(ctx, instr, sgpr_idx, info)) + apply_extract(ctx, instr, sgpr_idx, info); + else if (info.is_extract()) + continue; instr->operands[sgpr_idx] = Operand(sgpr); } else if (can_swap_operands(instr)) { instr->operands[sgpr_idx] = instr->operands[0]; @@ -2744,7 +2926,7 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr& instr) /* swap bits using a 4-entry LUT */ uint32_t swapped = (0x3120 >> (operand_mask & 0x3)) & 0xf; operand_mask = (operand_mask & ~0x3) | swapped; - } else if (can_use_VOP3(ctx, instr)) { + } else if (can_use_VOP3(ctx, instr) && !info.is_extract()) { to_VOP3(ctx, instr); instr->operands[sgpr_idx] = Operand(sgpr); } else { @@ -2755,6 +2937,11 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr& instr) sgpr_ids[num_sgprs++] = sgpr.id(); ctx.uses[sgpr_info_id]--; ctx.uses[sgpr.id()]++; + + /* TODO: handle when it's a VGPR */ + if ((ctx.info[sgpr.id()].label & (label_extract | label_temp)) && + ctx.info[sgpr.id()].temp.type() == RegType::sgpr) + operand_mask |= 1u << sgpr_idx; } } @@ -2819,7 +3006,51 @@ bool apply_omod_clamp(opt_ctx &ctx, aco_ptr& instr) } instr->definitions[0].swapTemp(def_info.instr->definitions[0]); - ctx.info[instr->definitions[0].tempId()].label &= label_clamp; + ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_insert; + ctx.uses[def_info.instr->definitions[0].tempId()]--; + + return true; +} + +/* Combine an p_insert (or p_extract, in some cases) instruction with instr. + * p_insert(instr(...)) -> instr_insert(). + */ +bool apply_insert(opt_ctx &ctx, aco_ptr& instr) +{ + if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1) + return false; + + ssa_info& def_info = ctx.info[instr->definitions[0].tempId()]; + if (!def_info.is_insert()) + return false; + /* if the insert instruction is dead, then the single user of this + * instruction is a different instruction */ + if (!ctx.uses[def_info.instr->definitions[0].tempId()]) + return false; + + /* MADs/FMAs are created later, so we don't have to update the original add */ + assert(!ctx.info[instr->definitions[0].tempId()].is_mad()); + + unsigned sel = parse_insert(def_info.instr); + + if (instr->isVOP3() && (sel & sdwa_isword) && !(sel & sdwa_sext) && + can_use_opsel(ctx.program->chip_class, instr->opcode, 3, (sel & sdwa_wordnum))) { + if (instr->vop3().opsel & (1 << 3)) + return false; + if (sel & sdwa_wordnum) + instr->vop3().opsel |= 1 << 3; + } else { + if (!can_use_SDWA(ctx.program->chip_class, instr, true)) + return false; + + to_SDWA(ctx, instr); + if ((static_cast(instr.get())->dst_sel & sdwa_asuint) != sdwa_udword) + return false; + static_cast(instr.get())->dst_sel = sel; + } + + instr->definitions[0].swapTemp(def_info.instr->definitions[0]); + ctx.info[instr->definitions[0].tempId()].label = 0; ctx.uses[def_info.instr->definitions[0].tempId()]--; return true; @@ -3077,9 +3308,26 @@ void combine_instruction(opt_ctx &ctx, aco_ptr& instr) return; if (instr->isVALU()) { + /* Apply SDWA. Do this after label_instruction() so it can remove + * label_extract if not all instructions can take SDWA. */ + for (unsigned i = 0; i < instr->operands.size(); i++) { + Operand& op = instr->operands[i]; + if (!op.isTemp()) + continue; + ssa_info& info = ctx.info[op.tempId()]; + if (info.is_extract() && (info.instr->operands[0].getTemp().type() == RegType::vgpr || + instr->operands[i].getTemp().type() == RegType::sgpr) && + can_apply_extract(ctx, instr, i, info)) { + apply_extract(ctx, instr, i, info); + ctx.uses[instr->operands[i].tempId()]--; + instr->operands[i].setTemp(info.instr->operands[0].getTemp()); + } + } + if (can_apply_sgprs(ctx, instr)) apply_sgprs(ctx, instr); while (apply_omod_clamp(ctx, instr)) ; + apply_insert(ctx, instr); } if (instr->isVOP3P()) @@ -3495,7 +3743,7 @@ void select_instruction(opt_ctx &ctx, aco_ptr& instr) if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) { mad_info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].instr->pass_flags]; /* re-check mad instructions */ - if (ctx.uses[mad_info->mul_temp_id]) { + if (ctx.uses[mad_info->mul_temp_id] && mad_info->add_instr) { ctx.uses[mad_info->mul_temp_id]++; if (instr->operands[0].isTemp()) ctx.uses[instr->operands[0].tempId()]--; diff --git a/src/amd/compiler/aco_register_allocation.cpp b/src/amd/compiler/aco_register_allocation.cpp index 87f0a70dca5..ab643bfdc85 100644 --- a/src/amd/compiler/aco_register_allocation.cpp +++ b/src/amd/compiler/aco_register_allocation.cpp @@ -427,7 +427,7 @@ unsigned get_subdword_operand_stride(chip_class chip, const aco_ptr if (instr->opcode == aco_opcode::v_cvt_f32_ubyte0) { return 1; - } else if (can_use_SDWA(chip, instr)) { + } else if (can_use_SDWA(chip, instr, false)) { return rc.bytes() % 2 == 0 ? 2 : 1; } else if (rc.bytes() == 2 && can_use_opsel(chip, instr->opcode, idx, 1)) { return 2; @@ -479,7 +479,7 @@ void add_subdword_operand(ra_ctx& ctx, aco_ptr& instr, unsigned idx break; } return; - } else if (can_use_SDWA(chip, instr)) { + } else if (can_use_SDWA(chip, instr, false)) { aco_ptr tmp = convert_to_SDWA(chip, instr); return; } else if (rc.bytes() == 2 && can_use_opsel(chip, instr->opcode, idx, byte / 2)) { @@ -550,7 +550,7 @@ std::pair get_subdword_definition_info(Program *program, con bytes_written = bytes_written > 4 ? align(bytes_written, 4) : bytes_written; bytes_written = MAX2(bytes_written, instr_info.definition_size[(int)instr->opcode] / 8u); - if (can_use_SDWA(chip, instr)) { + if (can_use_SDWA(chip, instr, false)) { return std::make_pair(rc.bytes(), rc.bytes()); } else if (rc.bytes() == 2 && can_use_opsel(chip, instr->opcode, -1, 1)) { return std::make_pair(2u, bytes_written); @@ -587,7 +587,7 @@ void add_subdword_definition(Program *program, aco_ptr& instr, unsi if (instr->isPseudo()) { return; - } else if (can_use_SDWA(chip, instr)) { + } else if (can_use_SDWA(chip, instr, false)) { unsigned def_size = instr_info.definition_size[(int)instr->opcode]; if (reg.byte() || chip < GFX10 || def_size > rc.bytes() * 8u) convert_to_SDWA(chip, instr);