diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 07a31319094..979a988edbc 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -782,29 +782,36 @@ void emit_vop2_instruction_logic64(isel_context *ctx, nir_alu_instr *instr, } void emit_vop3a_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst, - bool flush_denorms = false) + bool flush_denorms = false, unsigned num_sources = 2) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = get_alu_src(ctx, instr->src[1]); - Temp src2 = get_alu_src(ctx, instr->src[2]); - - /* ensure that the instruction has at most 1 sgpr operand - * The optimizer will inline constants for us */ - if (src0.type() == RegType::sgpr && src1.type() == RegType::sgpr) - src0 = as_vgpr(ctx, src0); - if (src1.type() == RegType::sgpr && src2.type() == RegType::sgpr) - src1 = as_vgpr(ctx, src1); - if (src2.type() == RegType::sgpr && src0.type() == RegType::sgpr) - src2 = as_vgpr(ctx, src2); + assert(num_sources == 2 || num_sources == 3); + Temp src[3] = { Temp(0, v1), Temp(0, v1), Temp(0, v1) }; + bool has_sgpr = false; + for (unsigned i = 0; i < num_sources; i++) { + src[i] = get_alu_src(ctx, instr->src[i]); + if (has_sgpr) + src[i] = as_vgpr(ctx, src[i]); + else + has_sgpr = src[i].type() == RegType::sgpr; + } Builder bld(ctx->program, ctx->block); bld.is_precise = instr->exact; if (flush_denorms && ctx->program->chip_class < GFX9) { assert(dst.size() == 1); - Temp tmp = bld.vop3(op, Definition(dst), src0, src1, src2); - bld.vop2(aco_opcode::v_mul_f32, Definition(dst), Operand(0x3f800000u), tmp); + Temp tmp; + if (num_sources == 3) + tmp = bld.vop3(op, bld.def(dst.regClass()), src[0], src[1], src[2]); + else + tmp = bld.vop3(op, bld.def(dst.regClass()), src[0], src[1]); + if (dst.size() == 1) + bld.vop2(aco_opcode::v_mul_f32, Definition(dst), Operand(0x3f800000u), tmp); + else + bld.vop3(aco_opcode::v_mul_f64, Definition(dst), Operand(0x3FF0000000000000lu), tmp); + } else if (num_sources == 3) { + bld.vop3(op, Definition(dst), src[0], src[1], src[2]); } else { - bld.vop3(op, Definition(dst), src0, src1, src2); + bld.vop3(op, Definition(dst), src[0], src[1]); } } @@ -1407,8 +1414,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) bld.vop3(aco_opcode::v_lshrrev_b64, Definition(dst), get_alu_src(ctx, instr->src[1]), get_alu_src(ctx, instr->src[0])); } else if (dst.regClass() == v2) { - bld.vop3(aco_opcode::v_lshr_b64, Definition(dst), - get_alu_src(ctx, instr->src[0]), get_alu_src(ctx, instr->src[1])); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_lshr_b64, dst); } else if (dst.regClass() == s2) { emit_sop2_instruction(ctx, instr, aco_opcode::s_lshr_b64, dst, true); } else if (dst.regClass() == s1) { @@ -1425,8 +1431,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) bld.vop3(aco_opcode::v_lshlrev_b64, Definition(dst), get_alu_src(ctx, instr->src[1]), get_alu_src(ctx, instr->src[0])); } else if (dst.regClass() == v2) { - bld.vop3(aco_opcode::v_lshl_b64, Definition(dst), - get_alu_src(ctx, instr->src[0]), get_alu_src(ctx, instr->src[1])); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_lshl_b64, dst); } else if (dst.regClass() == s1) { emit_sop2_instruction(ctx, instr, aco_opcode::s_lshl_b32, dst, true); } else if (dst.regClass() == s2) { @@ -1443,8 +1448,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) bld.vop3(aco_opcode::v_ashrrev_i64, Definition(dst), get_alu_src(ctx, instr->src[1]), get_alu_src(ctx, instr->src[0])); } else if (dst.regClass() == v2) { - bld.vop3(aco_opcode::v_ashr_i64, Definition(dst), - get_alu_src(ctx, instr->src[0]), get_alu_src(ctx, instr->src[1])); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_ashr_i64, dst); } else if (dst.regClass() == s1) { emit_sop2_instruction(ctx, instr, aco_opcode::s_ashr_i32, dst, true); } else if (dst.regClass() == s2) { @@ -1672,8 +1676,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_imul: { if (dst.regClass() == v1) { - bld.vop3(aco_opcode::v_mul_lo_u32, Definition(dst), - get_alu_src(ctx, instr->src[0]), get_alu_src(ctx, instr->src[1])); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_lo_u32, dst); } else if (dst.regClass() == s1) { emit_sop2_instruction(ctx, instr, aco_opcode::s_mul_i32, dst, false); } else { @@ -1683,7 +1686,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_umul_high: { if (dst.regClass() == v1) { - bld.vop3(aco_opcode::v_mul_hi_u32, Definition(dst), get_alu_src(ctx, instr->src[0]), get_alu_src(ctx, instr->src[1])); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_hi_u32, dst); } else if (dst.regClass() == s1 && ctx->options->chip_class >= GFX9) { emit_sop2_instruction(ctx, instr, aco_opcode::s_mul_hi_u32, dst, false); } else if (dst.regClass() == s1) { @@ -1697,7 +1700,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_imul_high: { if (dst.regClass() == v1) { - bld.vop3(aco_opcode::v_mul_hi_i32, Definition(dst), get_alu_src(ctx, instr->src[0]), get_alu_src(ctx, instr->src[1])); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_hi_i32, dst); } else if (dst.regClass() == s1 && ctx->options->chip_class >= GFX9) { emit_sop2_instruction(ctx, instr, aco_opcode::s_mul_hi_i32, dst, false); } else if (dst.regClass() == s1) { @@ -1715,9 +1718,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f32, dst, true); } else if (dst.regClass() == v2) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); - bld.vop3(aco_opcode::v_mul_f64, Definition(dst), src0, src1); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_f64, dst); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); } @@ -1729,9 +1730,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f32, dst, true); } else if (dst.regClass() == v2) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); - bld.vop3(aco_opcode::v_add_f64, Definition(dst), src0, src1); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_add_f64, dst); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); } @@ -1767,14 +1766,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32); } else if (dst.regClass() == v2) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); - if (ctx->block->fp_mode.must_flush_denorms16_64 && ctx->program->chip_class < GFX9) { - Temp tmp = bld.vop3(aco_opcode::v_max_f64, bld.def(v2), src0, src1); - bld.vop3(aco_opcode::v_mul_f64, Definition(dst), Operand(0x3FF0000000000000lu), tmp); - } else { - bld.vop3(aco_opcode::v_max_f64, Definition(dst), src0, src1); - } + emit_vop3a_instruction(ctx, instr, aco_opcode::v_max_f64, dst, ctx->block->fp_mode.must_flush_denorms16_64); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); } @@ -1787,14 +1779,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32); } else if (dst.regClass() == v2) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); - if (ctx->block->fp_mode.must_flush_denorms16_64 && ctx->program->chip_class < GFX9) { - Temp tmp = bld.vop3(aco_opcode::v_min_f64, bld.def(v2), src0, src1); - bld.vop3(aco_opcode::v_mul_f64, Definition(dst), Operand(0x3FF0000000000000lu), tmp); - } else { - bld.vop3(aco_opcode::v_min_f64, Definition(dst), src0, src1); - } + emit_vop3a_instruction(ctx, instr, aco_opcode::v_min_f64, dst, ctx->block->fp_mode.must_flush_denorms16_64); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); } @@ -2083,14 +2068,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_ldexp: { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = get_alu_src(ctx, instr->src[1]); if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_ldexp_f16, dst, false); } else if (dst.regClass() == v1) { - bld.vop3(aco_opcode::v_ldexp_f32, Definition(dst), as_vgpr(ctx, src0), src1); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_ldexp_f32, dst); } else if (dst.regClass() == v2) { - bld.vop3(aco_opcode::v_ldexp_f64, Definition(dst), as_vgpr(ctx, src0), src1); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_ldexp_f64, dst); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); } @@ -2719,13 +2702,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_bitfield_select: { - /* (mask & insert) | (~mask & base) */ - Temp bitmask = get_alu_src(ctx, instr->src[0]); - Temp insert = get_alu_src(ctx, instr->src[1]); - Temp base = get_alu_src(ctx, instr->src[2]); /* dst = (insert & bitmask) | (base & ~bitmask) */ if (dst.regClass() == s1) { + Temp bitmask = get_alu_src(ctx, instr->src[0]); + Temp insert = get_alu_src(ctx, instr->src[1]); + Temp base = get_alu_src(ctx, instr->src[2]); aco_ptr sop2; nir_const_value* const_bitmask = nir_src_as_const_value(instr->src[0].src); nir_const_value* const_insert = nir_src_as_const_value(instr->src[1].src); @@ -2749,13 +2731,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) bld.sop2(aco_opcode::s_or_b32, Definition(dst), bld.def(s1, scc), rhs, lhs); } else if (dst.regClass() == v1) { - if (base.type() == RegType::sgpr && (bitmask.type() == RegType::sgpr || (insert.type() == RegType::sgpr))) - base = as_vgpr(ctx, base); - if (insert.type() == RegType::sgpr && bitmask.type() == RegType::sgpr) - insert = as_vgpr(ctx, insert); - - bld.vop3(aco_opcode::v_bfi_b32, Definition(dst), bitmask, insert, base); - + emit_vop3a_instruction(ctx, instr, aco_opcode::v_bfi_b32, dst, false, 3); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); } @@ -2796,7 +2772,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } else { aco_opcode opcode = instr->op == nir_op_ubfe ? aco_opcode::v_bfe_u32 : aco_opcode::v_bfe_i32; - emit_vop3a_instruction(ctx, instr, opcode, dst); + emit_vop3a_instruction(ctx, instr, opcode, dst, false, 3); } break; }