aco/gfx11.5: select SALU float conversions

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29245>
This commit is contained in:
Georg Lehmann
2023-09-21 20:38:00 +02:00
committed by Marge Bot
parent 4399c7bac3
commit a90d4d340c
2 changed files with 115 additions and 62 deletions

View File

@@ -1345,12 +1345,16 @@ emit_vec2_f2f16(isel_context* ctx, nir_alu_instr* instr, Temp dst)
Temp src0 = emit_extract_vector(ctx, src, instr->src[0].swizzle[0], rc);
Temp src1 = emit_extract_vector(ctx, src, instr->src[0].swizzle[1], rc);
src1 = as_vgpr(ctx, src1);
if (ctx->program->gfx_level == GFX8 || ctx->program->gfx_level == GFX9)
bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32_e64, Definition(dst), src0, src1);
else
bld.vop2(aco_opcode::v_cvt_pkrtz_f16_f32, Definition(dst), src0, src1);
emit_split_vector(ctx, dst, 2);
if (dst.regClass() == s1) {
bld.sop2(aco_opcode::s_cvt_pk_rtz_f16_f32, Definition(dst), src0, src1);
} else {
src1 = as_vgpr(ctx, src1);
if (ctx->program->gfx_level == GFX8 || ctx->program->gfx_level == GFX9)
bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32_e64, Definition(dst), src0, src1);
else
bld.vop2(aco_opcode::v_cvt_pkrtz_f16_f32, Definition(dst), src0, src1);
emit_split_vector(ctx, dst, 2);
}
}
void
@@ -2929,13 +2933,20 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
break;
}
Temp src = get_alu_src(ctx, instr->src[0]);
if (instr->op == nir_op_f2f16_rtne && ctx->block->fp_mode.round16_64 != fp_round_ne)
if (instr->op == nir_op_f2f16_rtne && ctx->block->fp_mode.round16_64 != fp_round_ne) {
/* We emit s_round_mode/s_setreg_imm32 in lower_to_hw_instr to
* keep value numbering and the scheduler simpler.
*/
bld.vop1(aco_opcode::p_v_cvt_f16_f32_rtne, Definition(dst), src);
else
bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src);
if (dst.regClass() == v2b)
bld.vop1(aco_opcode::p_v_cvt_f16_f32_rtne, Definition(dst), src);
else
bld.sop1(aco_opcode::p_s_cvt_f16_f32_rtne, Definition(dst), src);
} else {
if (dst.regClass() == v2b)
bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src);
else
bld.sop1(aco_opcode::s_cvt_f16_f32, Definition(dst), src);
}
break;
}
case nir_op_f2f16_rtz: {
@@ -2945,16 +2956,26 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
break;
}
Temp src = get_alu_src(ctx, instr->src[0]);
if (ctx->block->fp_mode.round16_64 == fp_round_tz)
bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src);
else if (ctx->program->gfx_level == GFX8 || ctx->program->gfx_level == GFX9)
if (ctx->block->fp_mode.round16_64 == fp_round_tz) {
if (dst.regClass() == v2b)
bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src);
else
bld.sop1(aco_opcode::s_cvt_f16_f32, Definition(dst), src);
} else if (dst.regClass() == s1) {
bld.sop2(aco_opcode::s_cvt_pk_rtz_f16_f32, Definition(dst), src, Operand::zero());
} else if (ctx->program->gfx_level == GFX8 || ctx->program->gfx_level == GFX9) {
bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32_e64, Definition(dst), src, Operand::zero());
else
} else {
bld.vop2(aco_opcode::v_cvt_pkrtz_f16_f32, Definition(dst), src, as_vgpr(ctx, src));
}
break;
}
case nir_op_f2f32: {
if (instr->src[0].src.ssa->bit_size == 16) {
if (dst.regClass() == s1) {
assert(instr->src[0].src.ssa->bit_size == 16);
Temp src = get_alu_src(ctx, instr->src[0]);
bld.sop1(aco_opcode::s_cvt_f32_f16, Definition(dst), src);
} else if (instr->src[0].src.ssa->bit_size == 16) {
emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_f16, dst);
} else if (instr->src[0].src.ssa->bit_size == 64) {
emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_f64, dst);
@@ -2970,27 +2991,36 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
break;
}
case nir_op_i2f16: {
assert(dst.regClass() == v2b);
Temp src = get_alu_src(ctx, instr->src[0]);
const unsigned input_size = instr->src[0].src.ssa->bit_size;
if (input_size <= 16) {
/* Expand integer to the size expected by the uint→float converter used below */
unsigned target_size = (ctx->program->gfx_level >= GFX8 ? 16 : 32);
if (input_size != target_size) {
src = convert_int(ctx, bld, src, input_size, target_size, true);
if (dst.regClass() == v2b) {
if (input_size <= 16) {
/* Expand integer to the size expected by the uint→float converter used below */
unsigned target_size = (ctx->program->gfx_level >= GFX8 ? 16 : 32);
if (input_size != target_size) {
src = convert_int(ctx, bld, src, input_size, target_size, true);
}
}
}
if (ctx->program->gfx_level >= GFX8 && input_size <= 16) {
bld.vop1(aco_opcode::v_cvt_f16_i16, Definition(dst), src);
if (ctx->program->gfx_level >= GFX8 && input_size <= 16) {
bld.vop1(aco_opcode::v_cvt_f16_i16, Definition(dst), src);
} else {
/* Large 32bit inputs need to return +-inf/FLOAT_MAX.
*
* This is also the fallback-path taken on GFX7 and earlier, which
* do not support direct f16⟷i16 conversions.
*/
src = bld.vop1(aco_opcode::v_cvt_f32_i32, bld.def(v1), src);
bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src);
}
} else if (dst.regClass() == s1) {
if (input_size <= 16) {
src = convert_int(ctx, bld, src, input_size, 32, true);
}
src = bld.sop1(aco_opcode::s_cvt_f32_i32, bld.def(s1), src);
bld.sop1(aco_opcode::s_cvt_f16_f32, Definition(dst), src);
} else {
/* Large 32bit inputs need to return +-inf/FLOAT_MAX.
*
* This is also the fallback-path taken on GFX7 and earlier, which
* do not support direct f16⟷i16 conversions.
*/
src = bld.vop1(aco_opcode::v_cvt_f32_i32, bld.def(v1), src);
bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src);
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
break;
}
@@ -3003,7 +3033,10 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
/* Sign-extend to 32-bits */
src = convert_int(ctx, bld, src, input_size, 32, true);
}
bld.vop1(aco_opcode::v_cvt_f32_i32, Definition(dst), src);
if (dst.regClass() == v1)
bld.vop1(aco_opcode::v_cvt_f32_i32, Definition(dst), src);
else
bld.sop1(aco_opcode::s_cvt_f32_i32, Definition(dst), src);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
@@ -3021,27 +3054,36 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
break;
}
case nir_op_u2f16: {
assert(dst.regClass() == v2b);
Temp src = get_alu_src(ctx, instr->src[0]);
const unsigned input_size = instr->src[0].src.ssa->bit_size;
if (input_size <= 16) {
/* Expand integer to the size expected by the uint→float converter used below */
unsigned target_size = (ctx->program->gfx_level >= GFX8 ? 16 : 32);
if (input_size != target_size) {
src = convert_int(ctx, bld, src, input_size, target_size, false);
if (dst.regClass() == v2b) {
if (input_size <= 16) {
/* Expand integer to the size expected by the uint→float converter used below */
unsigned target_size = (ctx->program->gfx_level >= GFX8 ? 16 : 32);
if (input_size != target_size) {
src = convert_int(ctx, bld, src, input_size, target_size, false);
}
}
}
if (ctx->program->gfx_level >= GFX8 && input_size <= 16) {
bld.vop1(aco_opcode::v_cvt_f16_u16, Definition(dst), src);
if (ctx->program->gfx_level >= GFX8 && input_size <= 16) {
bld.vop1(aco_opcode::v_cvt_f16_u16, Definition(dst), src);
} else {
/* Large 32bit inputs need to return inf/FLOAT_MAX.
*
* This is also the fallback-path taken on GFX7 and earlier, which
* do not support direct f16⟷u16 conversions.
*/
src = bld.vop1(aco_opcode::v_cvt_f32_u32, bld.def(v1), src);
bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src);
}
} else if (dst.regClass() == s1) {
if (input_size <= 16) {
src = convert_int(ctx, bld, src, input_size, 32, false);
}
src = bld.sop1(aco_opcode::s_cvt_f32_u32, bld.def(s1), src);
bld.sop1(aco_opcode::s_cvt_f16_f32, Definition(dst), src);
} else {
/* Large 32bit inputs need to return inf/FLOAT_MAX.
*
* This is also the fallback-path taken on GFX7 and earlier, which
* do not support direct f16⟷u16 conversions.
*/
src = bld.vop1(aco_opcode::v_cvt_f32_u32, bld.def(v1), src);
bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src);
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
break;
}
@@ -3049,12 +3091,15 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
assert(dst.size() == 1);
Temp src = get_alu_src(ctx, instr->src[0]);
const unsigned input_size = instr->src[0].src.ssa->bit_size;
if (input_size == 8) {
if (input_size == 8 && dst.regClass() == v1) {
bld.vop1(aco_opcode::v_cvt_f32_ubyte0, Definition(dst), src);
} else if (input_size <= 32) {
if (input_size == 16)
if (input_size <= 16)
src = convert_int(ctx, bld, src, instr->src[0].src.ssa->bit_size, 32, false);
bld.vop1(aco_opcode::v_cvt_f32_u32, Definition(dst), src);
if (dst.regClass() == v1)
bld.vop1(aco_opcode::v_cvt_f32_u32, Definition(dst), src);
else
bld.sop1(aco_opcode::s_cvt_f32_u32, Definition(dst), src);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
@@ -3416,6 +3461,10 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
}
case nir_op_unpack_half_2x16_split_x: {
Temp src = get_alu_src(ctx, instr->src[0]);
if (dst.regClass() == s1) {
bld.sop1(aco_opcode::s_cvt_f32_f16, Definition(dst), src);
break;
}
if (src.regClass() == v1)
src = bld.pseudo(aco_opcode::p_split_vector, bld.def(v2b), bld.def(v2b), src);
if (dst.regClass() == v1) {
@@ -3427,6 +3476,10 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
}
case nir_op_unpack_half_2x16_split_y: {
Temp src = get_alu_src(ctx, instr->src[0]);
if (dst.regClass() == s1) {
bld.sop1(aco_opcode::s_cvt_hi_f32_f16, Definition(dst), src);
break;
}
if (src.regClass() == s1)
src = bld.pseudo(aco_opcode::p_extract, bld.def(s1), bld.def(s1, scc), src,
Operand::c32(1u), Operand::c32(16u), Operand::zero());

View File

@@ -347,16 +347,8 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_op_flog2:
case nir_op_fsin_amd:
case nir_op_fcos_amd:
case nir_op_f2f16:
case nir_op_f2f16_rtz:
case nir_op_f2f16_rtne:
case nir_op_f2f32:
case nir_op_f2f64:
case nir_op_u2f16:
case nir_op_u2f32:
case nir_op_u2f64:
case nir_op_i2f16:
case nir_op_i2f32:
case nir_op_i2f64:
case nir_op_pack_half_2x16_rtz_split:
case nir_op_pack_half_2x16_split:
@@ -364,8 +356,6 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_op_pack_snorm_2x16:
case nir_op_pack_uint_2x16:
case nir_op_pack_sint_2x16:
case nir_op_unpack_half_2x16_split_x:
case nir_op_unpack_half_2x16_split_y:
case nir_op_fddx:
case nir_op_fddy:
case nir_op_fddx_fine:
@@ -389,11 +379,21 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_op_sdot_2x16_iadd:
case nir_op_udot_2x16_uadd_sat:
case nir_op_sdot_2x16_iadd_sat: type = RegType::vgpr; break;
case nir_op_i2f16:
case nir_op_i2f32:
case nir_op_u2f16:
case nir_op_u2f32:
case nir_op_f2f16:
case nir_op_f2f16_rtz:
case nir_op_f2f16_rtne:
case nir_op_f2f32:
case nir_op_ffract:
case nir_op_ffloor:
case nir_op_fceil:
case nir_op_ftrunc:
case nir_op_fround_even: {
case nir_op_fround_even:
case nir_op_unpack_half_2x16_split_x:
case nir_op_unpack_half_2x16_split_y: {
if (ctx->program->gfx_level < GFX11_5 ||
alu_instr->src[0].src.ssa->bit_size > 32) {
type = RegType::vgpr;