aco/gfx11.5+: allow sgpr dst for trans ops and use pseudo scalar ops on gfx12
Also optimize the denorm scaling path by only emitting the expensive trans op once and allowing fma for the final muliplication. Reviewed-by: Daniel Schürmann <daniel@schuermann.dev> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29245>
This commit is contained in:
@@ -1152,67 +1152,84 @@ emit_bcsel(isel_context* ctx, nir_alu_instr* instr, Temp dst)
|
||||
}
|
||||
|
||||
void
|
||||
emit_scaled_op(isel_context* ctx, Builder& bld, Definition dst, Temp val, aco_opcode op,
|
||||
uint32_t undo)
|
||||
emit_scaled_op(isel_context* ctx, Builder& bld, Definition dst, Temp val, aco_opcode vop,
|
||||
aco_opcode sop, uint32_t undo)
|
||||
{
|
||||
if (ctx->block->fp_mode.denorm32 == 0) {
|
||||
if (dst.regClass() == v1)
|
||||
bld.vop1(vop, dst, val);
|
||||
else if (ctx->options->gfx_level >= GFX12)
|
||||
bld.vop3(sop, dst, val);
|
||||
else
|
||||
bld.pseudo(aco_opcode::p_as_uniform, dst, bld.vop1(vop, bld.def(v1), val));
|
||||
return;
|
||||
}
|
||||
|
||||
/* multiply by 16777216 to handle denormals */
|
||||
Temp scale, unscale;
|
||||
if (val.regClass() == v1) {
|
||||
val = as_vgpr(bld, val);
|
||||
Temp is_denormal = bld.tmp(bld.lm);
|
||||
VALU_instruction& valu =
|
||||
bld.vopc_e64(aco_opcode::v_cmp_class_f32, Definition(is_denormal), val, Operand::c32(1u << 4))
|
||||
VALU_instruction& valu = bld.vopc_e64(aco_opcode::v_cmp_class_f32, Definition(is_denormal),
|
||||
val, Operand::c32(1u << 4))
|
||||
->valu();
|
||||
valu.neg[0] = true;
|
||||
valu.abs[0] = true;
|
||||
Temp scaled = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand::c32(0x4b800000u), val);
|
||||
scaled = bld.vop1(op, bld.def(v1), scaled);
|
||||
scaled = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand::c32(undo), scaled);
|
||||
scale = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand::c32(0x3f800000),
|
||||
bld.copy(bld.def(s1), Operand::c32(0x4b800000u)), is_denormal);
|
||||
unscale = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand::c32(0x3f800000),
|
||||
bld.copy(bld.def(s1), Operand::c32(undo)), is_denormal);
|
||||
} else {
|
||||
Temp abs = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), val,
|
||||
bld.copy(bld.def(s1), Operand::c32(0x7fffffff)));
|
||||
Temp denorm_cmp = bld.copy(bld.def(s1), Operand::c32(0x00800000));
|
||||
Temp is_denormal = bld.sopc(aco_opcode::s_cmp_lt_u32, bld.def(s1, scc), abs, denorm_cmp);
|
||||
scale = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1),
|
||||
bld.copy(bld.def(s1), Operand::c32(0x4b800000u)), Operand::c32(0x3f800000),
|
||||
bld.scc(is_denormal));
|
||||
unscale =
|
||||
bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), bld.copy(bld.def(s1), Operand::c32(undo)),
|
||||
Operand::c32(0x3f800000), bld.scc(is_denormal));
|
||||
}
|
||||
|
||||
Temp not_scaled = bld.vop1(op, bld.def(v1), val);
|
||||
|
||||
bld.vop2(aco_opcode::v_cndmask_b32, dst, not_scaled, scaled, is_denormal);
|
||||
if (dst.regClass() == v1) {
|
||||
Temp scaled = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), scale, as_vgpr(bld, val));
|
||||
scaled = bld.vop1(vop, bld.def(v1), scaled);
|
||||
bld.vop2(aco_opcode::v_mul_f32, dst, unscale, scaled);
|
||||
} else {
|
||||
assert(ctx->options->gfx_level >= GFX11_5);
|
||||
Temp scaled = bld.sop2(aco_opcode::s_mul_f32, bld.def(s1), scale, val);
|
||||
if (ctx->options->gfx_level >= GFX12)
|
||||
scaled = bld.vop3(sop, bld.def(s1), scaled);
|
||||
else
|
||||
scaled = bld.as_uniform(bld.vop1(vop, bld.def(v1), scaled));
|
||||
bld.sop2(aco_opcode::s_mul_f32, dst, unscale, scaled);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
emit_rcp(isel_context* ctx, Builder& bld, Definition dst, Temp val)
|
||||
{
|
||||
if (ctx->block->fp_mode.denorm32 == 0) {
|
||||
bld.vop1(aco_opcode::v_rcp_f32, dst, val);
|
||||
return;
|
||||
}
|
||||
|
||||
emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_rcp_f32, 0x4b800000u);
|
||||
emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_rcp_f32, aco_opcode::v_s_rcp_f32, 0x4b800000u);
|
||||
}
|
||||
|
||||
void
|
||||
emit_rsq(isel_context* ctx, Builder& bld, Definition dst, Temp val)
|
||||
{
|
||||
if (ctx->block->fp_mode.denorm32 == 0) {
|
||||
bld.vop1(aco_opcode::v_rsq_f32, dst, val);
|
||||
return;
|
||||
}
|
||||
|
||||
emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_rsq_f32, 0x45800000u);
|
||||
emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_rsq_f32, aco_opcode::v_s_rsq_f32, 0x45800000u);
|
||||
}
|
||||
|
||||
void
|
||||
emit_sqrt(isel_context* ctx, Builder& bld, Definition dst, Temp val)
|
||||
{
|
||||
if (ctx->block->fp_mode.denorm32 == 0) {
|
||||
bld.vop1(aco_opcode::v_sqrt_f32, dst, val);
|
||||
return;
|
||||
}
|
||||
|
||||
emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_sqrt_f32, 0x39800000u);
|
||||
emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_sqrt_f32, aco_opcode::v_s_sqrt_f32,
|
||||
0x39800000u);
|
||||
}
|
||||
|
||||
void
|
||||
emit_log2(isel_context* ctx, Builder& bld, Definition dst, Temp val)
|
||||
{
|
||||
if (ctx->block->fp_mode.denorm32 == 0) {
|
||||
bld.vop1(aco_opcode::v_log_f32, dst, val);
|
||||
return;
|
||||
}
|
||||
|
||||
emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_log_f32, 0xc1c00000u);
|
||||
emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_log_f32, aco_opcode::v_s_log_f32, 0xc1c00000u);
|
||||
}
|
||||
|
||||
Temp
|
||||
@@ -2544,12 +2561,14 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
|
||||
break;
|
||||
}
|
||||
case nir_op_frsq: {
|
||||
if (dst.regClass() == v2b) {
|
||||
if (instr->def.bit_size == 16) {
|
||||
if (dst.regClass() == s1 && ctx->program->gfx_level >= GFX12)
|
||||
bld.vop3(aco_opcode::v_s_rsq_f16, Definition(dst), get_alu_src(ctx, instr->src[0]));
|
||||
else
|
||||
emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f16, dst);
|
||||
} else if (dst.regClass() == v1) {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
emit_rsq(ctx, bld, Definition(dst), src);
|
||||
} else if (dst.regClass() == v2) {
|
||||
} else if (instr->def.bit_size == 32) {
|
||||
emit_rsq(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
|
||||
} else if (instr->def.bit_size == 64) {
|
||||
/* Lowered at NIR level for precision reasons. */
|
||||
emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f64, dst);
|
||||
} else {
|
||||
@@ -2663,23 +2682,27 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
|
||||
break;
|
||||
}
|
||||
case nir_op_flog2: {
|
||||
if (dst.regClass() == v2b) {
|
||||
if (instr->def.bit_size == 16) {
|
||||
if (dst.regClass() == s1 && ctx->program->gfx_level >= GFX12)
|
||||
bld.vop3(aco_opcode::v_s_log_f16, Definition(dst), get_alu_src(ctx, instr->src[0]));
|
||||
else
|
||||
emit_vop1_instruction(ctx, instr, aco_opcode::v_log_f16, dst);
|
||||
} else if (dst.regClass() == v1) {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
emit_log2(ctx, bld, Definition(dst), src);
|
||||
} else if (instr->def.bit_size == 32) {
|
||||
emit_log2(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
|
||||
} else {
|
||||
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
|
||||
}
|
||||
break;
|
||||
}
|
||||
case nir_op_frcp: {
|
||||
if (dst.regClass() == v2b) {
|
||||
if (instr->def.bit_size == 16) {
|
||||
if (dst.regClass() == s1 && ctx->program->gfx_level >= GFX12)
|
||||
bld.vop3(aco_opcode::v_s_rcp_f16, Definition(dst), get_alu_src(ctx, instr->src[0]));
|
||||
else
|
||||
emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f16, dst);
|
||||
} else if (dst.regClass() == v1) {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
emit_rcp(ctx, bld, Definition(dst), src);
|
||||
} else if (dst.regClass() == v2) {
|
||||
} else if (instr->def.bit_size == 32) {
|
||||
emit_rcp(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
|
||||
} else if (instr->def.bit_size == 64) {
|
||||
/* Lowered at NIR level for precision reasons. */
|
||||
emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f64, dst);
|
||||
} else {
|
||||
@@ -2688,9 +2711,13 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
|
||||
break;
|
||||
}
|
||||
case nir_op_fexp2: {
|
||||
if (dst.regClass() == v2b) {
|
||||
if (dst.regClass() == s1 && ctx->options->gfx_level >= GFX12) {
|
||||
aco_opcode opcode =
|
||||
instr->def.bit_size == 16 ? aco_opcode::v_s_exp_f16 : aco_opcode::v_s_exp_f32;
|
||||
bld.vop3(opcode, Definition(dst), get_alu_src(ctx, instr->src[0]));
|
||||
} else if (instr->def.bit_size == 16) {
|
||||
emit_vop1_instruction(ctx, instr, aco_opcode::v_exp_f16, dst);
|
||||
} else if (dst.regClass() == v1) {
|
||||
} else if (instr->def.bit_size == 32) {
|
||||
emit_vop1_instruction(ctx, instr, aco_opcode::v_exp_f32, dst);
|
||||
} else {
|
||||
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
|
||||
@@ -2698,12 +2725,14 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
|
||||
break;
|
||||
}
|
||||
case nir_op_fsqrt: {
|
||||
if (dst.regClass() == v2b) {
|
||||
if (instr->def.bit_size == 16) {
|
||||
if (dst.regClass() == s1 && ctx->program->gfx_level >= GFX12)
|
||||
bld.vop3(aco_opcode::v_s_sqrt_f16, Definition(dst), get_alu_src(ctx, instr->src[0]));
|
||||
else
|
||||
emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f16, dst);
|
||||
} else if (dst.regClass() == v1) {
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
emit_sqrt(ctx, bld, Definition(dst), src);
|
||||
} else if (dst.regClass() == v2) {
|
||||
} else if (instr->def.bit_size == 32) {
|
||||
emit_sqrt(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
|
||||
} else if (instr->def.bit_size == 64) {
|
||||
/* Lowered at NIR level for precision reasons. */
|
||||
emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f64, dst);
|
||||
} else {
|
||||
@@ -2858,20 +2887,31 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
|
||||
}
|
||||
case nir_op_fsin_amd:
|
||||
case nir_op_fcos_amd: {
|
||||
Temp src = as_vgpr(ctx, get_alu_src(ctx, instr->src[0]));
|
||||
aco_ptr<Instruction> norm;
|
||||
if (dst.regClass() == v2b) {
|
||||
aco_opcode opcode =
|
||||
instr->op == nir_op_fsin_amd ? aco_opcode::v_sin_f16 : aco_opcode::v_cos_f16;
|
||||
bld.vop1(opcode, Definition(dst), src);
|
||||
} else if (dst.regClass() == v1) {
|
||||
/* before GFX9, v_sin_f32 and v_cos_f32 had a valid input domain of [-256, +256] */
|
||||
if (ctx->options->gfx_level < GFX9)
|
||||
src = bld.vop1(aco_opcode::v_fract_f32, bld.def(v1), src);
|
||||
if (instr->def.bit_size == 16 || instr->def.bit_size == 32) {
|
||||
bool is_sin = instr->op == nir_op_fsin_amd;
|
||||
aco_opcode opcode, fract;
|
||||
RegClass rc;
|
||||
if (instr->def.bit_size == 16) {
|
||||
opcode = is_sin ? aco_opcode::v_sin_f16 : aco_opcode::v_cos_f16;
|
||||
fract = aco_opcode::v_fract_f16;
|
||||
rc = v2b;
|
||||
} else {
|
||||
opcode = is_sin ? aco_opcode::v_sin_f32 : aco_opcode::v_cos_f32;
|
||||
fract = aco_opcode::v_fract_f32;
|
||||
rc = v1;
|
||||
}
|
||||
|
||||
aco_opcode opcode =
|
||||
instr->op == nir_op_fsin_amd ? aco_opcode::v_sin_f32 : aco_opcode::v_cos_f32;
|
||||
Temp src = get_alu_src(ctx, instr->src[0]);
|
||||
/* before GFX9, v_sin and v_cos had a valid input domain of [-256, +256] */
|
||||
if (ctx->options->gfx_level < GFX9)
|
||||
src = bld.vop1(fract, bld.def(rc), src);
|
||||
|
||||
if (dst.regClass() == rc) {
|
||||
bld.vop1(opcode, Definition(dst), src);
|
||||
} else {
|
||||
Temp tmp = bld.vop1(opcode, bld.def(rc), src);
|
||||
bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), tmp);
|
||||
}
|
||||
} else {
|
||||
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
|
||||
}
|
||||
|
@@ -332,13 +332,6 @@ init_context(isel_context* ctx, nir_shader* shader)
|
||||
case nir_op_ffmaz:
|
||||
case nir_op_fneg:
|
||||
case nir_op_fabs:
|
||||
case nir_op_frcp:
|
||||
case nir_op_frsq:
|
||||
case nir_op_fsqrt:
|
||||
case nir_op_fexp2:
|
||||
case nir_op_flog2:
|
||||
case nir_op_fsin_amd:
|
||||
case nir_op_fcos_amd:
|
||||
case nir_op_f2f64:
|
||||
case nir_op_u2f64:
|
||||
case nir_op_i2f64:
|
||||
@@ -390,6 +383,13 @@ init_context(isel_context* ctx, nir_shader* shader)
|
||||
case nir_op_fceil:
|
||||
case nir_op_ftrunc:
|
||||
case nir_op_fround_even:
|
||||
case nir_op_frcp:
|
||||
case nir_op_frsq:
|
||||
case nir_op_fsqrt:
|
||||
case nir_op_fexp2:
|
||||
case nir_op_flog2:
|
||||
case nir_op_fsin_amd:
|
||||
case nir_op_fcos_amd:
|
||||
case nir_op_pack_half_2x16_rtz_split:
|
||||
case nir_op_pack_half_2x16_split:
|
||||
case nir_op_unpack_half_2x16_split_x:
|
||||
|
Reference in New Issue
Block a user