diff --git a/src/amd/compiler/aco_insert_exec_mask.cpp b/src/amd/compiler/aco_insert_exec_mask.cpp index 37201622eef..c3a8f09f018 100644 --- a/src/amd/compiler/aco_insert_exec_mask.cpp +++ b/src/amd/compiler/aco_insert_exec_mask.cpp @@ -235,8 +235,7 @@ add_coupling_code(exec_ctx& ctx, Block* block, std::vector> assert(preds[0] == idx - 1); ctx.info[idx].exec = ctx.info[idx - 1].exec; loop_info& info = ctx.loop.back(); - while (ctx.info[idx].exec.size() > info.num_exec_masks) - ctx.info[idx].exec.pop_back(); + assert(ctx.info[idx].exec.size() == info.num_exec_masks); /* create ssa names for outer exec masks */ if (info.has_discard) { @@ -250,17 +249,6 @@ add_coupling_code(exec_ctx& ctx, Block* block, std::vector> } } - /* create ssa name for restore mask */ - if (info.has_divergent_break) { - // TODO: this phi is unnecessary if we end WQM immediately after the loop - /* this phi might be trivial but ensures a parallelcopy on the loop header */ - aco_ptr phi{create_instruction( - aco_opcode::p_linear_phi, Format::PSEUDO, preds.size(), 1)}; - phi->definitions[0] = bld.def(bld.lm); - phi->operands[0] = get_exec_op(ctx.info[preds[0]].exec[info.num_exec_masks - 1].first); - ctx.info[idx].exec.back().first = bld.insert(std::move(phi)); - } - /* create ssa name for loop active mask */ aco_ptr phi{create_instruction( aco_opcode::p_linear_phi, Format::PSEUDO, preds.size(), 1)}; @@ -269,16 +257,8 @@ add_coupling_code(exec_ctx& ctx, Block* block, std::vector> else phi->definitions[0] = Definition(exec, bld.lm); phi->operands[0] = get_exec_op(ctx.info[preds[0]].exec.back().first); - Temp loop_active = bld.insert(std::move(phi)); - - if (info.has_divergent_break) { - uint8_t mask_type = - (ctx.info[idx].exec.back().second & (mask_type_wqm | mask_type_exact)) | mask_type_loop; - ctx.info[idx].exec.emplace_back(loop_active, mask_type); - } else { - ctx.info[idx].exec.back().first = Operand(loop_active); - ctx.info[idx].exec.back().second |= mask_type_loop; - } + ctx.info[idx].exec.back().first = bld.insert(std::move(phi)); + ctx.info[idx].exec.back().second |= mask_type_loop; /* create a parallelcopy to move the active mask to exec */ if (info.has_divergent_continue) { @@ -318,13 +298,9 @@ add_coupling_code(exec_ctx& ctx, Block* block, std::vector> if (info.has_divergent_break) { restore_exec = true; - aco_ptr& phi = header->instructions[instr_idx]; - assert(phi->opcode == aco_opcode::p_linear_phi); - for (unsigned i = 1; i < phi->operands.size(); i++) - phi->operands[i] = - get_exec_op(ctx.info[header_preds[i]].exec[info.num_exec_masks].first); + /* Drop the loop active mask. */ + info.num_exec_masks--; } - assert(!(block->kind & block_kind_top_level) || info.num_exec_masks <= 2); /* create the loop exit phis if not trivial */ @@ -345,10 +321,6 @@ add_coupling_code(exec_ctx& ctx, Block* block, std::vector> aco_ptr phi{create_instruction( aco_opcode::p_linear_phi, Format::PSEUDO, preds.size(), 1)}; phi->definitions[0] = bld.def(bld.lm); - if (exec_idx == info.num_exec_masks - 1u) { - phi->definitions[0] = Definition(exec, bld.lm); - restore_exec = false; - } for (unsigned i = 0; i < phi->operands.size(); i++) phi->operands[i] = get_exec_op(ctx.info[preds[i]].exec[exec_idx].first); ctx.info[idx].exec.emplace_back(bld.insert(std::move(phi)), type); @@ -659,9 +631,20 @@ add_branch_code(exec_ctx& ctx, Block* block) has_divergent_continue = true; } + if (has_divergent_break) { + /* save restore exec mask */ + uint8_t mask = ctx.info[idx].exec.back().second; + if (ctx.info[idx].exec.back().first.constantEquals(-1u)) { + ctx.info[idx].exec.emplace_back(Operand(exec, bld.lm), mask); + } else { + bld.reset(bld.instructions, std::prev(bld.instructions->end())); + Operand restore = bld.copy(bld.def(bld.lm), Operand(exec, bld.lm)); + ctx.info[idx].exec.emplace(std::prev(ctx.info[idx].exec.end()), restore, mask); + bld.reset(bld.instructions); + } + ctx.info[idx].exec.back().second &= (mask_type_wqm | mask_type_exact); + } unsigned num_exec_masks = ctx.info[idx].exec.size(); - if (block->kind & block_kind_top_level) - num_exec_masks = std::min(num_exec_masks, 2u); ctx.loop.emplace_back(&ctx.program->blocks[block->linear_succs[0]], num_exec_masks, has_divergent_break, has_divergent_continue, has_discard);