diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 78bbb42a2f7..669209c268c 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -1152,67 +1152,84 @@ emit_bcsel(isel_context* ctx, nir_alu_instr* instr, Temp dst) } void -emit_scaled_op(isel_context* ctx, Builder& bld, Definition dst, Temp val, aco_opcode op, - uint32_t undo) +emit_scaled_op(isel_context* ctx, Builder& bld, Definition dst, Temp val, aco_opcode vop, + aco_opcode sop, uint32_t undo) { + if (ctx->block->fp_mode.denorm32 == 0) { + if (dst.regClass() == v1) + bld.vop1(vop, dst, val); + else if (ctx->options->gfx_level >= GFX12) + bld.vop3(sop, dst, val); + else + bld.pseudo(aco_opcode::p_as_uniform, dst, bld.vop1(vop, bld.def(v1), val)); + return; + } + /* multiply by 16777216 to handle denormals */ - Temp is_denormal = bld.tmp(bld.lm); - VALU_instruction& valu = - bld.vopc_e64(aco_opcode::v_cmp_class_f32, Definition(is_denormal), val, Operand::c32(1u << 4)) - ->valu(); - valu.neg[0] = true; - valu.abs[0] = true; - Temp scaled = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand::c32(0x4b800000u), val); - scaled = bld.vop1(op, bld.def(v1), scaled); - scaled = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand::c32(undo), scaled); + Temp scale, unscale; + if (val.regClass() == v1) { + val = as_vgpr(bld, val); + Temp is_denormal = bld.tmp(bld.lm); + VALU_instruction& valu = bld.vopc_e64(aco_opcode::v_cmp_class_f32, Definition(is_denormal), + val, Operand::c32(1u << 4)) + ->valu(); + valu.neg[0] = true; + valu.abs[0] = true; + scale = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand::c32(0x3f800000), + bld.copy(bld.def(s1), Operand::c32(0x4b800000u)), is_denormal); + unscale = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand::c32(0x3f800000), + bld.copy(bld.def(s1), Operand::c32(undo)), is_denormal); + } else { + Temp abs = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), val, + bld.copy(bld.def(s1), Operand::c32(0x7fffffff))); + Temp denorm_cmp = bld.copy(bld.def(s1), Operand::c32(0x00800000)); + Temp is_denormal = bld.sopc(aco_opcode::s_cmp_lt_u32, bld.def(s1, scc), abs, denorm_cmp); + scale = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), + bld.copy(bld.def(s1), Operand::c32(0x4b800000u)), Operand::c32(0x3f800000), + bld.scc(is_denormal)); + unscale = + bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), bld.copy(bld.def(s1), Operand::c32(undo)), + Operand::c32(0x3f800000), bld.scc(is_denormal)); + } - Temp not_scaled = bld.vop1(op, bld.def(v1), val); - - bld.vop2(aco_opcode::v_cndmask_b32, dst, not_scaled, scaled, is_denormal); + if (dst.regClass() == v1) { + Temp scaled = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), scale, as_vgpr(bld, val)); + scaled = bld.vop1(vop, bld.def(v1), scaled); + bld.vop2(aco_opcode::v_mul_f32, dst, unscale, scaled); + } else { + assert(ctx->options->gfx_level >= GFX11_5); + Temp scaled = bld.sop2(aco_opcode::s_mul_f32, bld.def(s1), scale, val); + if (ctx->options->gfx_level >= GFX12) + scaled = bld.vop3(sop, bld.def(s1), scaled); + else + scaled = bld.as_uniform(bld.vop1(vop, bld.def(v1), scaled)); + bld.sop2(aco_opcode::s_mul_f32, dst, unscale, scaled); + } } void emit_rcp(isel_context* ctx, Builder& bld, Definition dst, Temp val) { - if (ctx->block->fp_mode.denorm32 == 0) { - bld.vop1(aco_opcode::v_rcp_f32, dst, val); - return; - } - - emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_rcp_f32, 0x4b800000u); + emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_rcp_f32, aco_opcode::v_s_rcp_f32, 0x4b800000u); } void emit_rsq(isel_context* ctx, Builder& bld, Definition dst, Temp val) { - if (ctx->block->fp_mode.denorm32 == 0) { - bld.vop1(aco_opcode::v_rsq_f32, dst, val); - return; - } - - emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_rsq_f32, 0x45800000u); + emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_rsq_f32, aco_opcode::v_s_rsq_f32, 0x45800000u); } void emit_sqrt(isel_context* ctx, Builder& bld, Definition dst, Temp val) { - if (ctx->block->fp_mode.denorm32 == 0) { - bld.vop1(aco_opcode::v_sqrt_f32, dst, val); - return; - } - - emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_sqrt_f32, 0x39800000u); + emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_sqrt_f32, aco_opcode::v_s_sqrt_f32, + 0x39800000u); } void emit_log2(isel_context* ctx, Builder& bld, Definition dst, Temp val) { - if (ctx->block->fp_mode.denorm32 == 0) { - bld.vop1(aco_opcode::v_log_f32, dst, val); - return; - } - - emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_log_f32, 0xc1c00000u); + emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_log_f32, aco_opcode::v_s_log_f32, 0xc1c00000u); } Temp @@ -2544,12 +2561,14 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) break; } case nir_op_frsq: { - if (dst.regClass() == v2b) { - emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f16, dst); - } else if (dst.regClass() == v1) { - Temp src = get_alu_src(ctx, instr->src[0]); - emit_rsq(ctx, bld, Definition(dst), src); - } else if (dst.regClass() == v2) { + if (instr->def.bit_size == 16) { + if (dst.regClass() == s1 && ctx->program->gfx_level >= GFX12) + bld.vop3(aco_opcode::v_s_rsq_f16, Definition(dst), get_alu_src(ctx, instr->src[0])); + else + emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f16, dst); + } else if (instr->def.bit_size == 32) { + emit_rsq(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0])); + } else if (instr->def.bit_size == 64) { /* Lowered at NIR level for precision reasons. */ emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f64, dst); } else { @@ -2663,23 +2682,27 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) break; } case nir_op_flog2: { - if (dst.regClass() == v2b) { - emit_vop1_instruction(ctx, instr, aco_opcode::v_log_f16, dst); - } else if (dst.regClass() == v1) { - Temp src = get_alu_src(ctx, instr->src[0]); - emit_log2(ctx, bld, Definition(dst), src); + if (instr->def.bit_size == 16) { + if (dst.regClass() == s1 && ctx->program->gfx_level >= GFX12) + bld.vop3(aco_opcode::v_s_log_f16, Definition(dst), get_alu_src(ctx, instr->src[0])); + else + emit_vop1_instruction(ctx, instr, aco_opcode::v_log_f16, dst); + } else if (instr->def.bit_size == 32) { + emit_log2(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0])); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); } break; } case nir_op_frcp: { - if (dst.regClass() == v2b) { - emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f16, dst); - } else if (dst.regClass() == v1) { - Temp src = get_alu_src(ctx, instr->src[0]); - emit_rcp(ctx, bld, Definition(dst), src); - } else if (dst.regClass() == v2) { + if (instr->def.bit_size == 16) { + if (dst.regClass() == s1 && ctx->program->gfx_level >= GFX12) + bld.vop3(aco_opcode::v_s_rcp_f16, Definition(dst), get_alu_src(ctx, instr->src[0])); + else + emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f16, dst); + } else if (instr->def.bit_size == 32) { + emit_rcp(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0])); + } else if (instr->def.bit_size == 64) { /* Lowered at NIR level for precision reasons. */ emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f64, dst); } else { @@ -2688,9 +2711,13 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) break; } case nir_op_fexp2: { - if (dst.regClass() == v2b) { + if (dst.regClass() == s1 && ctx->options->gfx_level >= GFX12) { + aco_opcode opcode = + instr->def.bit_size == 16 ? aco_opcode::v_s_exp_f16 : aco_opcode::v_s_exp_f32; + bld.vop3(opcode, Definition(dst), get_alu_src(ctx, instr->src[0])); + } else if (instr->def.bit_size == 16) { emit_vop1_instruction(ctx, instr, aco_opcode::v_exp_f16, dst); - } else if (dst.regClass() == v1) { + } else if (instr->def.bit_size == 32) { emit_vop1_instruction(ctx, instr, aco_opcode::v_exp_f32, dst); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); @@ -2698,12 +2725,14 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) break; } case nir_op_fsqrt: { - if (dst.regClass() == v2b) { - emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f16, dst); - } else if (dst.regClass() == v1) { - Temp src = get_alu_src(ctx, instr->src[0]); - emit_sqrt(ctx, bld, Definition(dst), src); - } else if (dst.regClass() == v2) { + if (instr->def.bit_size == 16) { + if (dst.regClass() == s1 && ctx->program->gfx_level >= GFX12) + bld.vop3(aco_opcode::v_s_sqrt_f16, Definition(dst), get_alu_src(ctx, instr->src[0])); + else + emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f16, dst); + } else if (instr->def.bit_size == 32) { + emit_sqrt(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0])); + } else if (instr->def.bit_size == 64) { /* Lowered at NIR level for precision reasons. */ emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f64, dst); } else { @@ -2858,20 +2887,31 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) } case nir_op_fsin_amd: case nir_op_fcos_amd: { - Temp src = as_vgpr(ctx, get_alu_src(ctx, instr->src[0])); - aco_ptr norm; - if (dst.regClass() == v2b) { - aco_opcode opcode = - instr->op == nir_op_fsin_amd ? aco_opcode::v_sin_f16 : aco_opcode::v_cos_f16; - bld.vop1(opcode, Definition(dst), src); - } else if (dst.regClass() == v1) { - /* before GFX9, v_sin_f32 and v_cos_f32 had a valid input domain of [-256, +256] */ - if (ctx->options->gfx_level < GFX9) - src = bld.vop1(aco_opcode::v_fract_f32, bld.def(v1), src); + if (instr->def.bit_size == 16 || instr->def.bit_size == 32) { + bool is_sin = instr->op == nir_op_fsin_amd; + aco_opcode opcode, fract; + RegClass rc; + if (instr->def.bit_size == 16) { + opcode = is_sin ? aco_opcode::v_sin_f16 : aco_opcode::v_cos_f16; + fract = aco_opcode::v_fract_f16; + rc = v2b; + } else { + opcode = is_sin ? aco_opcode::v_sin_f32 : aco_opcode::v_cos_f32; + fract = aco_opcode::v_fract_f32; + rc = v1; + } - aco_opcode opcode = - instr->op == nir_op_fsin_amd ? aco_opcode::v_sin_f32 : aco_opcode::v_cos_f32; - bld.vop1(opcode, Definition(dst), src); + Temp src = get_alu_src(ctx, instr->src[0]); + /* before GFX9, v_sin and v_cos had a valid input domain of [-256, +256] */ + if (ctx->options->gfx_level < GFX9) + src = bld.vop1(fract, bld.def(rc), src); + + if (dst.regClass() == rc) { + bld.vop1(opcode, Definition(dst), src); + } else { + Temp tmp = bld.vop1(opcode, bld.def(rc), src); + bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), tmp); + } } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); } diff --git a/src/amd/compiler/aco_instruction_selection_setup.cpp b/src/amd/compiler/aco_instruction_selection_setup.cpp index f06cbaf6453..1593cb3c127 100644 --- a/src/amd/compiler/aco_instruction_selection_setup.cpp +++ b/src/amd/compiler/aco_instruction_selection_setup.cpp @@ -332,13 +332,6 @@ init_context(isel_context* ctx, nir_shader* shader) case nir_op_ffmaz: case nir_op_fneg: case nir_op_fabs: - case nir_op_frcp: - case nir_op_frsq: - case nir_op_fsqrt: - case nir_op_fexp2: - case nir_op_flog2: - case nir_op_fsin_amd: - case nir_op_fcos_amd: case nir_op_f2f64: case nir_op_u2f64: case nir_op_i2f64: @@ -390,6 +383,13 @@ init_context(isel_context* ctx, nir_shader* shader) case nir_op_fceil: case nir_op_ftrunc: case nir_op_fround_even: + case nir_op_frcp: + case nir_op_frsq: + case nir_op_fsqrt: + case nir_op_fexp2: + case nir_op_flog2: + case nir_op_fsin_amd: + case nir_op_fcos_amd: case nir_op_pack_half_2x16_rtz_split: case nir_op_pack_half_2x16_split: case nir_op_unpack_half_2x16_split_x: