aco/gfx11.5+: allow sgpr dst for trans ops and use pseudo scalar ops on gfx12

Also optimize the denorm scaling path by emitting the expensive trans op only
once and by allowing fma for the final multiplication.

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29245>
Georg Lehmann
2023-09-22 11:28:15 +02:00
committed by Marge Bot
parent 314053a3e3
commit 284b9965e8
2 changed files with 123 additions and 83 deletions
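A note on the denorm workaround touched below: the f32 trans instructions flush denormal inputs even when the shader requests denormal support, so emit_scaled_op multiplies a denormal input by 2^24 (0x4b800000) to bring it into the normal range and multiplies the trans result by a per-op undo constant. A standalone C++20 sketch of why each undo constant works -- plain arithmetic, not ACO code; the constants are the ones from the diff below:

#include <bit>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
   /* Constants from emit_rcp/emit_rsq/emit_sqrt, reinterpreted as floats. */
   const double scale     = std::bit_cast<float>(0x4b800000u); /* 2^24  */
   const double undo_rcp  = std::bit_cast<float>(0x4b800000u); /* 2^24  */
   const double undo_rsq  = std::bit_cast<float>(0x45800000u); /* 2^12  */
   const double undo_sqrt = std::bit_cast<float>(0x39800000u); /* 2^-12 */

   /* A denormal f32 input, which the hardware trans ops would flush. */
   const double x = std::bit_cast<float>(0x00000400u);

   /* rcp(x)  = rcp(x * 2^24)  * 2^24
    * rsq(x)  = rsq(x * 2^24)  * 2^12
    * sqrt(x) = sqrt(x * 2^24) * 2^-12
    * Scaling by an even power of two keeps all three exact, which is why
    * each opcode needs its own undo constant. */
   printf("rcp:  %g == %g\n", 1.0 / (x * scale) * undo_rcp, 1.0 / x);
   printf("rsq:  %g == %g\n", 1.0 / std::sqrt(x * scale) * undo_rsq, 1.0 / std::sqrt(x));
   printf("sqrt: %g == %g\n", std::sqrt(x * scale) * undo_sqrt, std::sqrt(x));
}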

src/amd/compiler/aco_instruction_selection.cpp

@@ -1152,67 +1152,84 @@ emit_bcsel(isel_context* ctx, nir_alu_instr* instr, Temp dst)
 }

 void
-emit_scaled_op(isel_context* ctx, Builder& bld, Definition dst, Temp val, aco_opcode op,
-               uint32_t undo)
+emit_scaled_op(isel_context* ctx, Builder& bld, Definition dst, Temp val, aco_opcode vop,
+               aco_opcode sop, uint32_t undo)
 {
+   if (ctx->block->fp_mode.denorm32 == 0) {
+      if (dst.regClass() == v1)
+         bld.vop1(vop, dst, val);
+      else if (ctx->options->gfx_level >= GFX12)
+         bld.vop3(sop, dst, val);
+      else
+         bld.pseudo(aco_opcode::p_as_uniform, dst, bld.vop1(vop, bld.def(v1), val));
+      return;
+   }
+
    /* multiply by 16777216 to handle denormals */
-   Temp is_denormal = bld.tmp(bld.lm);
-   VALU_instruction& valu =
-      bld.vopc_e64(aco_opcode::v_cmp_class_f32, Definition(is_denormal), val, Operand::c32(1u << 4))
-         ->valu();
-   valu.neg[0] = true;
-   valu.abs[0] = true;
-   Temp scaled = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand::c32(0x4b800000u), val);
-   scaled = bld.vop1(op, bld.def(v1), scaled);
-   scaled = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand::c32(undo), scaled);
+   Temp scale, unscale;
+   if (val.regClass() == v1) {
+      val = as_vgpr(bld, val);
+      Temp is_denormal = bld.tmp(bld.lm);
+      VALU_instruction& valu = bld.vopc_e64(aco_opcode::v_cmp_class_f32, Definition(is_denormal),
+                                            val, Operand::c32(1u << 4))
+                                  ->valu();
+      valu.neg[0] = true;
+      valu.abs[0] = true;
+      scale = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand::c32(0x3f800000),
+                           bld.copy(bld.def(s1), Operand::c32(0x4b800000u)), is_denormal);
+      unscale = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand::c32(0x3f800000),
+                             bld.copy(bld.def(s1), Operand::c32(undo)), is_denormal);
+   } else {
+      Temp abs = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), val,
+                          bld.copy(bld.def(s1), Operand::c32(0x7fffffff)));
+      Temp denorm_cmp = bld.copy(bld.def(s1), Operand::c32(0x00800000));
+      Temp is_denormal = bld.sopc(aco_opcode::s_cmp_lt_u32, bld.def(s1, scc), abs, denorm_cmp);
+      scale = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1),
+                       bld.copy(bld.def(s1), Operand::c32(0x4b800000u)), Operand::c32(0x3f800000),
+                       bld.scc(is_denormal));
+      unscale =
+         bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), bld.copy(bld.def(s1), Operand::c32(undo)),
+                  Operand::c32(0x3f800000), bld.scc(is_denormal));
+   }

-   Temp not_scaled = bld.vop1(op, bld.def(v1), val);
-
-   bld.vop2(aco_opcode::v_cndmask_b32, dst, not_scaled, scaled, is_denormal);
+   if (dst.regClass() == v1) {
+      Temp scaled = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), scale, as_vgpr(bld, val));
+      scaled = bld.vop1(vop, bld.def(v1), scaled);
+      bld.vop2(aco_opcode::v_mul_f32, dst, unscale, scaled);
+   } else {
+      assert(ctx->options->gfx_level >= GFX11_5);
+      Temp scaled = bld.sop2(aco_opcode::s_mul_f32, bld.def(s1), scale, val);
+      if (ctx->options->gfx_level >= GFX12)
+         scaled = bld.vop3(sop, bld.def(s1), scaled);
+      else
+         scaled = bld.as_uniform(bld.vop1(vop, bld.def(v1), scaled));
+      bld.sop2(aco_opcode::s_mul_f32, dst, unscale, scaled);
+   }
 }

 void
 emit_rcp(isel_context* ctx, Builder& bld, Definition dst, Temp val)
 {
-   if (ctx->block->fp_mode.denorm32 == 0) {
-      bld.vop1(aco_opcode::v_rcp_f32, dst, val);
-      return;
-   }
-
-   emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_rcp_f32, 0x4b800000u);
+   emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_rcp_f32, aco_opcode::v_s_rcp_f32, 0x4b800000u);
 }

 void
 emit_rsq(isel_context* ctx, Builder& bld, Definition dst, Temp val)
 {
-   if (ctx->block->fp_mode.denorm32 == 0) {
-      bld.vop1(aco_opcode::v_rsq_f32, dst, val);
-      return;
-   }
-
-   emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_rsq_f32, 0x45800000u);
+   emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_rsq_f32, aco_opcode::v_s_rsq_f32, 0x45800000u);
 }

 void
 emit_sqrt(isel_context* ctx, Builder& bld, Definition dst, Temp val)
 {
-   if (ctx->block->fp_mode.denorm32 == 0) {
-      bld.vop1(aco_opcode::v_sqrt_f32, dst, val);
-      return;
-   }
-
-   emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_sqrt_f32, 0x39800000u);
+   emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_sqrt_f32, aco_opcode::v_s_sqrt_f32,
+                  0x39800000u);
 }

 void
 emit_log2(isel_context* ctx, Builder& bld, Definition dst, Temp val)
 {
-   if (ctx->block->fp_mode.denorm32 == 0) {
-      bld.vop1(aco_opcode::v_log_f32, dst, val);
-      return;
-   }
-
-   emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_log_f32, 0xc1c00000u);
+   emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_log_f32, aco_opcode::v_s_log_f32, 0xc1c00000u);
 }

 Temp
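The two is_denormal computations above do the same test in different register files: the VGPR path uses v_cmp_class_f32 with class mask 1u << 4 (negative subnormal) plus the abs and neg input modifiers, so it effectively classifies -|val| and catches denormals of either sign; the SGPR path has no scalar v_cmp_class, so it masks off the sign bit and does an unsigned compare against the smallest normal f32. A minimal C++ sketch of that scalar test (note it also matches zero, which is harmless for rcp/rsq/sqrt since 0.0 scales to 0.0):

#include <cstdint>

/* Sketch of the sgpr-path test: s_and_b32 with 0x7fffffff drops the sign,
 * then s_cmp_lt_u32 against 0x00800000 (the smallest normal f32, 2^-126)
 * is true for denormals and zero. */
bool is_denormal_or_zero(uint32_t float_bits)
{
   uint32_t abs = float_bits & 0x7fffffffu; /* s_and_b32    */
   return abs < 0x00800000u;                /* s_cmp_lt_u32 */
}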
@@ -2544,12 +2561,14 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
       break;
    }
    case nir_op_frsq: {
-      if (dst.regClass() == v2b) {
-         emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f16, dst);
-      } else if (dst.regClass() == v1) {
-         Temp src = get_alu_src(ctx, instr->src[0]);
-         emit_rsq(ctx, bld, Definition(dst), src);
-      } else if (dst.regClass() == v2) {
+      if (instr->def.bit_size == 16) {
+         if (dst.regClass() == s1 && ctx->program->gfx_level >= GFX12)
+            bld.vop3(aco_opcode::v_s_rsq_f16, Definition(dst), get_alu_src(ctx, instr->src[0]));
+         else
+            emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f16, dst);
+      } else if (instr->def.bit_size == 32) {
+         emit_rsq(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
+      } else if (instr->def.bit_size == 64) {
          /* Lowered at NIR level for precision reasons. */
          emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f64, dst);
       } else {
@@ -2663,23 +2682,27 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
       break;
    }
    case nir_op_flog2: {
-      if (dst.regClass() == v2b) {
-         emit_vop1_instruction(ctx, instr, aco_opcode::v_log_f16, dst);
-      } else if (dst.regClass() == v1) {
-         Temp src = get_alu_src(ctx, instr->src[0]);
-         emit_log2(ctx, bld, Definition(dst), src);
+      if (instr->def.bit_size == 16) {
+         if (dst.regClass() == s1 && ctx->program->gfx_level >= GFX12)
+            bld.vop3(aco_opcode::v_s_log_f16, Definition(dst), get_alu_src(ctx, instr->src[0]));
+         else
+            emit_vop1_instruction(ctx, instr, aco_opcode::v_log_f16, dst);
+      } else if (instr->def.bit_size == 32) {
+         emit_log2(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
       } else {
          isel_err(&instr->instr, "Unimplemented NIR instr bit size");
       }
       break;
    }
    case nir_op_frcp: {
-      if (dst.regClass() == v2b) {
-         emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f16, dst);
-      } else if (dst.regClass() == v1) {
-         Temp src = get_alu_src(ctx, instr->src[0]);
-         emit_rcp(ctx, bld, Definition(dst), src);
-      } else if (dst.regClass() == v2) {
+      if (instr->def.bit_size == 16) {
+         if (dst.regClass() == s1 && ctx->program->gfx_level >= GFX12)
+            bld.vop3(aco_opcode::v_s_rcp_f16, Definition(dst), get_alu_src(ctx, instr->src[0]));
+         else
+            emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f16, dst);
+      } else if (instr->def.bit_size == 32) {
+         emit_rcp(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
+      } else if (instr->def.bit_size == 64) {
          /* Lowered at NIR level for precision reasons. */
          emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f64, dst);
       } else {
@@ -2688,9 +2711,13 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
       break;
    }
    case nir_op_fexp2: {
-      if (dst.regClass() == v2b) {
+      if (dst.regClass() == s1 && ctx->options->gfx_level >= GFX12) {
+         aco_opcode opcode =
+            instr->def.bit_size == 16 ? aco_opcode::v_s_exp_f16 : aco_opcode::v_s_exp_f32;
+         bld.vop3(opcode, Definition(dst), get_alu_src(ctx, instr->src[0]));
+      } else if (instr->def.bit_size == 16) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_exp_f16, dst);
-      } else if (dst.regClass() == v1) {
+      } else if (instr->def.bit_size == 32) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_exp_f32, dst);
       } else {
          isel_err(&instr->instr, "Unimplemented NIR instr bit size");
@@ -2698,12 +2725,14 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
       break;
    }
    case nir_op_fsqrt: {
-      if (dst.regClass() == v2b) {
-         emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f16, dst);
-      } else if (dst.regClass() == v1) {
-         Temp src = get_alu_src(ctx, instr->src[0]);
-         emit_sqrt(ctx, bld, Definition(dst), src);
-      } else if (dst.regClass() == v2) {
+      if (instr->def.bit_size == 16) {
+         if (dst.regClass() == s1 && ctx->program->gfx_level >= GFX12)
+            bld.vop3(aco_opcode::v_s_sqrt_f16, Definition(dst), get_alu_src(ctx, instr->src[0]));
+         else
+            emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f16, dst);
+      } else if (instr->def.bit_size == 32) {
+         emit_sqrt(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
+      } else if (instr->def.bit_size == 64) {
          /* Lowered at NIR level for precision reasons. */
          emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f64, dst);
       } else {
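For reference, the VOP1-to-GFX12 pseudo-scalar opcode pairs wired up across the hunks above, collected in one place (a plain C++ table; the names are exactly the ones used in this patch):

/* VOP1 opcode -> GFX12 pseudo-scalar opcode, as emitted via bld.vop3 above. */
struct TransOpPair { const char* vop1; const char* pseudo_scalar; };
static const TransOpPair trans_op_pairs[] = {
   {"v_rcp_f32",  "v_s_rcp_f32"},  {"v_rcp_f16",  "v_s_rcp_f16"},
   {"v_rsq_f32",  "v_s_rsq_f32"},  {"v_rsq_f16",  "v_s_rsq_f16"},
   {"v_sqrt_f32", "v_s_sqrt_f32"}, {"v_sqrt_f16", "v_s_sqrt_f16"},
   {"v_log_f32",  "v_s_log_f32"},  {"v_log_f16",  "v_s_log_f16"},
   {"v_exp_f32",  "v_s_exp_f32"},  {"v_exp_f16",  "v_s_exp_f16"},
};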
@@ -2858,20 +2887,31 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
    }
    case nir_op_fsin_amd:
    case nir_op_fcos_amd: {
-      Temp src = as_vgpr(ctx, get_alu_src(ctx, instr->src[0]));
-      aco_ptr<Instruction> norm;
-      if (dst.regClass() == v2b) {
-         aco_opcode opcode =
-            instr->op == nir_op_fsin_amd ? aco_opcode::v_sin_f16 : aco_opcode::v_cos_f16;
-         bld.vop1(opcode, Definition(dst), src);
-      } else if (dst.regClass() == v1) {
-         /* before GFX9, v_sin_f32 and v_cos_f32 had a valid input domain of [-256, +256] */
-         if (ctx->options->gfx_level < GFX9)
-            src = bld.vop1(aco_opcode::v_fract_f32, bld.def(v1), src);
-
-         aco_opcode opcode =
-            instr->op == nir_op_fsin_amd ? aco_opcode::v_sin_f32 : aco_opcode::v_cos_f32;
-         bld.vop1(opcode, Definition(dst), src);
+      if (instr->def.bit_size == 16 || instr->def.bit_size == 32) {
+         bool is_sin = instr->op == nir_op_fsin_amd;
+         aco_opcode opcode, fract;
+         RegClass rc;
+         if (instr->def.bit_size == 16) {
+            opcode = is_sin ? aco_opcode::v_sin_f16 : aco_opcode::v_cos_f16;
+            fract = aco_opcode::v_fract_f16;
+            rc = v2b;
+         } else {
+            opcode = is_sin ? aco_opcode::v_sin_f32 : aco_opcode::v_cos_f32;
+            fract = aco_opcode::v_fract_f32;
+            rc = v1;
+         }
+
+         Temp src = get_alu_src(ctx, instr->src[0]);
+         /* before GFX9, v_sin and v_cos had a valid input domain of [-256, +256] */
+         if (ctx->options->gfx_level < GFX9)
+            src = bld.vop1(fract, bld.def(rc), src);
+
+         if (dst.regClass() == rc) {
+            bld.vop1(opcode, Definition(dst), src);
+         } else {
+            Temp tmp = bld.vop1(opcode, bld.def(rc), src);
+            bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), tmp);
+         }
       } else {
          isel_err(&instr->instr, "Unimplemented NIR instr bit size");
       }
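The v_fract clamp above relies on fsin_amd/fcos_amd taking a normalized angle (1.0 per full period; the NIR lowering divides the radian input by 2*pi), so dropping the integer part leaves the result unchanged while bringing the input into the pre-GFX9 domain of [-256, +256]. A small C++ sketch of that identity, assuming the normalized-angle convention:

#include <cmath>
#include <cstdio>

/* fsin_amd semantics (assumed here): sin of a normalized angle, 1.0 == 2*pi. */
static float sin_amd_ref(float turns) { return std::sin(turns * 6.2831853f); }

int main()
{
   float x = 1000.25f;            /* outside the pre-GFX9 domain [-256, +256] */
   float f = x - std::floor(x);   /* v_fract: the same angle modulo one turn */
   /* Equal up to float rounding in the reference evaluation. */
   printf("%f == %f\n", sin_amd_ref(f), sin_amd_ref(x));
}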

src/amd/compiler/aco_instruction_selection_setup.cpp

@@ -332,13 +332,6 @@ init_context(isel_context* ctx, nir_shader* shader)
             case nir_op_ffmaz:
             case nir_op_fneg:
             case nir_op_fabs:
-            case nir_op_frcp:
-            case nir_op_frsq:
-            case nir_op_fsqrt:
-            case nir_op_fexp2:
-            case nir_op_flog2:
-            case nir_op_fsin_amd:
-            case nir_op_fcos_amd:
             case nir_op_f2f64:
             case nir_op_u2f64:
             case nir_op_i2f64:
@@ -390,6 +383,13 @@ init_context(isel_context* ctx, nir_shader* shader)
             case nir_op_fceil:
             case nir_op_ftrunc:
             case nir_op_fround_even:
+            case nir_op_frcp:
+            case nir_op_frsq:
+            case nir_op_fsqrt:
+            case nir_op_fexp2:
+            case nir_op_flog2:
+            case nir_op_fsin_amd:
+            case nir_op_fcos_amd:
             case nir_op_pack_half_2x16_rtz_split:
             case nir_op_pack_half_2x16_split:
             case nir_op_unpack_half_2x16_split_x:
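These setup-file hunks move frcp, frsq, fsqrt, fexp2, flog2, fsin_amd and fcos_amd out of the case list that forces a VGPR definition and into the list of ops that may keep a scalar definition when the NIR value is uniform; instruction selection then emits the pseudo-scalar or p_as_uniform forms shown above. A hypothetical C++ sketch of the decision such a list feeds (illustrative names only, not ACO's real data structures):

/* Illustrative only: uniform defs of whitelisted ops may stay scalar;
 * everything divergent goes to the vector register file. */
enum class RegFile { sgpr, vgpr };

RegFile pick_reg_file(bool divergent, bool op_allows_sgpr_def)
{
   if (divergent || !op_allows_sgpr_def)
      return RegFile::vgpr;
   return RegFile::sgpr; /* e.g. frcp of a uniform value after this patch */
}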