aco: add VINTERP instruction format

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17333>
This commit is contained in:
Rhys Perry
2022-06-17 13:53:08 +01:00
committed by Marge Bot
parent 55cd74d468
commit aadb7aef01
12 changed files with 141 additions and 10 deletions

View File

@@ -374,6 +374,24 @@ emit_instruction(asm_context& ctx, std::vector<uint32_t>& out, Instruction* inst
}
break;
}
case Format::VINTERP_INREG: {
VINTERP_inreg_instruction& interp = instr->vinterp_inreg();
uint32_t encoding = (0b11001101 << 24);
encoding |= reg(ctx, instr->definitions[0], 8);
encoding |= (uint32_t)interp.wait_exp << 8;
encoding |= (uint32_t)interp.opsel << 11;
encoding |= (uint32_t)interp.clamp << 15;
encoding |= opcode << 16;
out.push_back(encoding);
encoding = 0;
for (unsigned i = 0; i < instr->operands.size(); i++)
encoding |= reg(ctx, instr->operands[i]) << (i * 9);
for (unsigned i = 0; i < 3; i++)
encoding |= interp.neg[i] << (29 + i);
out.push_back(encoding);
break;
}
case Format::DS: {
DS_instruction& ds = instr->ds();
uint32_t encoding = (0b110110 << 26);

View File

@@ -531,6 +531,7 @@ formats = [("pseudo", [Format.PSEUDO], 'Pseudo_instruction', list(itertools.prod
("vopc_sdwa", [Format.VOPC, Format.SDWA], 'SDWA_instruction', itertools.product([1, 2], [2])),
("vop3", [Format.VOP3], 'VOP3_instruction', [(1, 3), (1, 2), (1, 1), (2, 2)]),
("vop3p", [Format.VOP3P], 'VOP3P_instruction', [(1, 2), (1, 3)]),
("vinterp_inreg", [Format.VINTERP_INREG], 'VINTERP_inreg_instruction', [(1, 3)]),
("vintrp", [Format.VINTRP], 'VINTRP_instruction', [(1, 2), (1, 3)]),
("vop1_dpp", [Format.VOP1, Format.DPP16], 'DPP16_instruction', [(1, 1)]),
("vop2_dpp", [Format.VOP2, Format.DPP16], 'DPP16_instruction', itertools.product([1, 2], [2, 3])),

View File

@@ -758,6 +758,11 @@ handle_block(Program* program, Block& block, wait_ctx& ctx)
gen(instr.get(), ctx);
if (instr->format != Format::PSEUDO_BARRIER && !is_wait) {
if (instr->isVINTERP_INREG() && queued_imm.exp != wait_imm::unset_counter) {
instr->vinterp_inreg().wait_exp = MIN2(instr->vinterp_inreg().wait_exp, queued_imm.exp);
queued_imm.exp = wait_imm::unset_counter;
}
if (!queued_imm.empty())
emit_waitcnt(ctx, new_instructions, queued_imm);

View File

@@ -441,6 +441,11 @@ can_use_opsel(amd_gfx_level gfx_level, aco_opcode op, int idx)
case aco_opcode::v_mad_i32_i16: return idx >= 0 && idx < 2;
case aco_opcode::v_dot2_f16_f16:
case aco_opcode::v_dot2_bf16_bf16: return idx == -1 || idx == 2;
// TODO: This matches what LLVM allows. We should see if this matches what the hardware allows.
case aco_opcode::v_interp_p10_f16_f32_inreg:
case aco_opcode::v_interp_p10_rtz_f16_f32_inreg: return idx == 0 || idx == 2;
case aco_opcode::v_interp_p2_f16_f32_inreg:
case aco_opcode::v_interp_p2_rtz_f16_f32_inreg: return idx == -1 || idx == 0;
default: return false;
}
}
@@ -448,6 +453,8 @@ can_use_opsel(amd_gfx_level gfx_level, aco_opcode op, int idx)
bool
instr_is_16bit(amd_gfx_level gfx_level, aco_opcode op)
{
// TODO: VINTERP (v_interp_p2_f16_f32, v_interp_p2_rtz_f16_f32)
/* partial register writes are GFX9+, only */
if (gfx_level < GFX9)
return false;

View File

@@ -96,6 +96,7 @@ enum class Format : std::uint16_t {
/* Vector ALU Formats */
VOP3P = 20,
VINTERP_INREG = 21,
VOP1 = 1 << 8,
VOP2 = 1 << 9,
VOPC = 1 << 10,
@@ -1010,6 +1011,7 @@ struct Pseudo_branch_instruction;
struct Pseudo_barrier_instruction;
struct Pseudo_reduction_instruction;
struct VOP3P_instruction;
struct VINTERP_inreg_instruction;
struct VOP1_instruction;
struct VOP2_instruction;
struct VOPC_instruction;
@@ -1258,6 +1260,17 @@ struct Instruction {
return *(VOP3P_instruction*)this;
}
constexpr bool isVOP3P() const noexcept { return format == Format::VOP3P; }
VINTERP_inreg_instruction& vinterp_inreg() noexcept
{
assert(isVINTERP_INREG());
return *(VINTERP_inreg_instruction*)this;
}
const VINTERP_inreg_instruction& vinterp_inreg() const noexcept
{
assert(isVINTERP_INREG());
return *(VINTERP_inreg_instruction*)this;
}
constexpr bool isVINTERP_INREG() const noexcept { return format == Format::VINTERP_INREG; }
VOP1_instruction& vop1() noexcept
{
assert(isVOP1());
@@ -1446,6 +1459,14 @@ struct VOP3P_instruction : public Instruction {
};
static_assert(sizeof(VOP3P_instruction) == sizeof(Instruction) + 8, "Unexpected padding");
struct VINTERP_inreg_instruction : public Instruction {
uint8_t wait_exp : 3;
bool clamp : 1;
uint8_t opsel : 4;
bool neg[3];
};
static_assert(sizeof(VINTERP_inreg_instruction) == sizeof(Instruction) + 4, "Unexpected padding");
/**
* Data Parallel Primitives Format:
* This format can be used for VOP1, VOP2 or VOPC instructions.

View File

@@ -2414,7 +2414,7 @@ lower_to_hw_instr(Program* program)
can_remove = false;
} else if (inst->isSALU()) {
num_scalar++;
} else if (inst->isVALU() || inst->isVINTRP()) {
} else if (inst->isVALU() || inst->isVINTRP() || instr->isVINTERP_INREG()) {
num_vector++;
/* VALU which writes SGPRs are always executed on GFX10+ */
if (ctx.program->gfx_level >= GFX10) {

View File

@@ -70,6 +70,7 @@ class Format(Enum):
PSEUDO_BARRIER = 18
PSEUDO_REDUCTION = 19
VOP3P = 20
VINTERP_INREG = 21
VOP1 = 1 << 8
VOP2 = 1 << 9
VOPC = 1 << 10
@@ -163,6 +164,9 @@ class Format(Enum):
elif self == Format.VOP3P:
return [('uint8_t', 'opsel_lo', None),
('uint8_t', 'opsel_hi', None)]
elif self == Format.VINTERP_INREG:
return [('unsigned', 'wait_exp', 7),
('uint8_t', 'opsel', 0)]
elif self in [Format.FLAT, Format.GLOBAL, Format.SCRATCH]:
return [('int16_t', 'offset', 0),
('memory_sync_info', 'sync', 'memory_sync_info()'),
@@ -999,7 +1003,7 @@ opcode("v_dot2_f32_f16", -1, 0x23, 0x13, 0x13, Format.VOP3P, InstrClass.Valu32)
opcode("v_dot2_f32_bf16", -1, -1, -1, 0x1a, Format.VOP3P, InstrClass.Valu32)
# VINTERP instructions:
# VINTRP (GFX6 - GFX10.3) instructions:
VINTRP = {
(0x00, "v_interp_p1_f32"),
(0x01, "v_interp_p2_f32"),
@@ -1009,6 +1013,20 @@ VINTRP = {
for (code, name) in VINTRP:
opcode(name, code, code, code, -1, Format.VINTRP, InstrClass.Valu32)
# VINTERP (GFX11+) instructions:
VINTERP = {
(0x00, "v_interp_p10_f32_inreg"),
(0x01, "v_interp_p2_f32_inreg"),
(0x02, "v_interp_p10_f16_f32_inreg"),
(0x03, "v_interp_p2_f16_f32_inreg"),
(0x04, "v_interp_p10_rtz_f16_f32_inreg"),
(0x05, "v_interp_p2_rtz_f16_f32_inreg"),
}
for (code, name) in VINTERP:
opcode(name, -1, -1, -1, code, Format.VINTERP_INREG, InstrClass.Valu32)
# VOP3 instructions: 3 inputs, 1 output
# VOP3b instructions: have a unique scalar output, e.g. VOP2 with vcc out
VOP3 = {

View File

@@ -99,6 +99,7 @@ struct InstrHash {
switch (instr->format) {
case Format::SMEM: return hash_murmur_32<SMEM_instruction>(instr);
case Format::VINTRP: return hash_murmur_32<VINTRP_instruction>(instr);
case Format::VINTERP_INREG: return hash_murmur_32<VINTERP_inreg_instruction>(instr);
case Format::DS: return hash_murmur_32<DS_instruction>(instr);
case Format::SOPP: return hash_murmur_32<SOPP_instruction>(instr);
case Format::SOPK: return hash_murmur_32<SOPK_instruction>(instr);
@@ -235,6 +236,12 @@ struct InstrPred {
return a3P.opsel_lo == b3P.opsel_lo && a3P.opsel_hi == b3P.opsel_hi &&
a3P.clamp == b3P.clamp;
}
case Format::VINTERP_INREG: {
VINTERP_inreg_instruction& aI = a->vinterp_inreg();
VINTERP_inreg_instruction& bI = b->vinterp_inreg();
return aI.wait_exp == bI.wait_exp && aI.clamp == bI.clamp && aI.opsel == bI.opsel &&
aI.neg[0] == bI.neg[0] && aI.neg[1] == bI.neg[1] && aI.neg[2] == bI.neg[2];
}
case Format::PSEUDO_REDUCTION: {
Pseudo_reduction_instruction& aR = a->reduction();
Pseudo_reduction_instruction& bR = b->reduction();

View File

@@ -347,6 +347,12 @@ print_instr_format_specific(const Instruction* instr, FILE* output)
print_sync(smem.sync, output);
break;
}
case Format::VINTERP_INREG: {
const VINTERP_inreg_instruction& vinterp = instr->vinterp_inreg();
if (vinterp.wait_exp != 7)
fprintf(output, " wait_exp:%u", vinterp.wait_exp);
break;
}
case Format::VINTRP: {
const VINTRP_instruction& vintrp = instr->vintrp();
fprintf(output, " attr%d.%c", vintrp.attribute, "xyzw"[vintrp.component]);
@@ -655,6 +661,12 @@ print_instr_format_specific(const Instruction* instr, FILE* output)
default: break;
}
}
} else if (instr->isVINTERP_INREG()) {
const VINTERP_inreg_instruction& vinterp = instr->vinterp_inreg();
if (vinterp.clamp)
fprintf(output, " clamp");
if (vinterp.opsel & (1 << 3))
fprintf(output, " opsel_hi");
}
}
@@ -714,6 +726,12 @@ aco_print_instr(const Instruction* instr, FILE* output, unsigned flags)
f2f32[i] = vop3p.opsel_hi & (1 << i);
opsel[i] = f2f32[i] && (vop3p.opsel_lo & (1 << i));
}
} else if (instr->isVINTERP_INREG()) {
const VINTERP_inreg_instruction& vinterp = instr->vinterp_inreg();
for (unsigned i = 0; i < MIN2(num_operands, 3); ++i) {
neg[i] = vinterp.neg[i];
opsel[i] = vinterp.opsel & (1 << i);
}
}
for (unsigned i = 0; i < num_operands; ++i) {
if (i)

View File

@@ -503,7 +503,7 @@ get_subdword_operand_stride(amd_gfx_level gfx_level, const aco_ptr<Instruction>&
}
assert(rc.bytes() <= 2);
if (instr->isVALU()) {
if (instr->isVALU() || instr->isVINTERP_INREG()) {
if (can_use_SDWA(gfx_level, instr, false))
return rc.bytes();
if (can_use_opsel(gfx_level, instr->opcode, idx))
@@ -538,13 +538,18 @@ add_subdword_operand(ra_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, uns
return;
assert(rc.bytes() <= 2);
if (instr->isVALU()) {
if (instr->isVALU() || instr->isVINTERP_INREG()) {
/* check if we can use opsel */
if (instr->format == Format::VOP3) {
assert(byte == 2);
instr->vop3().opsel |= 1 << idx;
return;
}
if (instr->isVINTERP_INREG()) {
assert(byte == 2);
instr->vinterp_inreg().opsel |= 1 << idx;
return;
}
if (instr->isVOP3P()) {
assert(byte == 2 && !(instr->vop3p().opsel_lo & (1 << idx)));
instr->vop3p().opsel_lo |= 1 << idx;
@@ -608,7 +613,7 @@ get_subdword_definition_info(Program* program, const aco_ptr<Instruction>& instr
return std::make_pair(4, rc.size() * 4u);
}
if (instr->isVALU() || instr->isVINTRP()) {
if (instr->isVALU() || instr->isVINTRP() || instr->isVINTERP_INREG()) {
assert(rc.bytes() <= 2);
if (can_use_SDWA(gfx_level, instr, false))
@@ -676,7 +681,7 @@ add_subdword_definition(Program* program, aco_ptr<Instruction>& instr, PhysReg r
if (instr->isPseudo())
return;
if (instr->isVALU()) {
if (instr->isVALU() || instr->isVINTERP_INREG()) {
amd_gfx_level gfx_level = program->gfx_level;
assert(instr->definitions[0].bytes() <= 2);
@@ -689,6 +694,11 @@ add_subdword_definition(Program* program, aco_ptr<Instruction>& instr, PhysReg r
assert(can_use_opsel(gfx_level, instr->opcode, -1));
instr->vop3().opsel |= (1 << 3); /* dst in high half */
return;
} else if (instr->isVINTERP_INREG()) {
assert(reg.byte() == 2);
assert(can_use_opsel(gfx_level, instr->opcode, -1));
instr->vinterp_inreg().opsel |= (1 << 3); /* dst in high half */
return;
}
if (instr->opcode == aco_opcode::v_fma_mixlo_f16) {

View File

@@ -281,7 +281,7 @@ validate_ir(Program* program)
instr.get());
}
if (instr->isSALU() || instr->isVALU()) {
if (instr->isSALU() || instr->isVALU() || instr->isVINTERP_INREG()) {
/* check literals */
Operand literal(s1);
for (unsigned i = 0; i < instr->operands.size(); i++) {
@@ -303,7 +303,7 @@ validate_ir(Program* program)
}
/* check num sgprs for VALU */
if (instr->isVALU()) {
if (instr->isVALU() || instr->isVINTERP_INREG()) {
bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
instr->opcode == aco_opcode::v_lshrrev_b64 ||
instr->opcode == aco_opcode::v_ashrrev_i64;
@@ -311,7 +311,8 @@ validate_ir(Program* program)
if (program->gfx_level >= GFX10 && !is_shift64)
const_bus_limit = 2;
uint32_t scalar_mask = instr->isVOP3() || instr->isVOP3P() ? 0x7 : 0x5;
uint32_t scalar_mask =
instr->isVOP3() || instr->isVOP3P() || instr->isVINTERP_INREG() ? 0x7 : 0x5;
if (instr->isSDWA())
scalar_mask = program->gfx_level >= GFX9 ? 0x7 : 0x4;
else if (instr->isDPP())
@@ -898,7 +899,7 @@ get_subdword_bytes_written(Program* program, const aco_ptr<Instruction>& instr,
if (instr->isPseudo())
return gfx_level >= GFX8 ? def.bytes() : def.size() * 4u;
if (instr->isVALU()) {
if (instr->isVALU() || instr->isVINTERP_INREG()) {
assert(def.bytes() <= 2);
if (instr->isSDWA())
return instr->sdwa().dst_sel.size();

View File

@@ -379,3 +379,28 @@ BEGIN_TEST(regalloc.branch_def_phis_at_branch_block)
finish_ra_test(ra_test_policy());
END_TEST
BEGIN_TEST(regalloc.vinterp_fp16)
//>> v1: %in0:v[0], v1: %in1:v[1], v1: %in2:v[2] = p_startpgm
if (!setup_cs("v1 v1 v1", GFX11))
return;
//! v2b: %lo:v[3][0:16], v2b: %hi:v[3][16:32] = p_split_vector %in0:v[0]
Temp lo = bld.tmp(v2b);
Temp hi = bld.tmp(v2b);
bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), inputs[0]);
//! v1: %tmp0:v[1] = v_interp_p10_f16_f32_inreg %lo:v[3][0:16], %in1:v[1], hi(%hi:v[3][16:32])
//! p_unit_test %tmp0:v[1]
Temp tmp0 = bld.vinterp_inreg(aco_opcode::v_interp_p10_f16_f32_inreg, bld.def(v1), lo, inputs[1], hi);
bld.pseudo(aco_opcode::p_unit_test, tmp0);
//! v2b: %tmp1:v[0][16:32] = v_interp_p2_f16_f32_inreg %in0:v[0], %in2:v[2], %tmp0:v[1] opsel_hi
//! v1: %tmp2:v[0] = p_create_vector 0, %tmp1:v[0][16:32]
//! p_unit_test %tmp2:v[0]
Temp tmp1 = bld.vinterp_inreg(aco_opcode::v_interp_p2_f16_f32_inreg, bld.def(v2b), inputs[0], inputs[2], tmp0);
Temp tmp2 = bld.pseudo(aco_opcode::p_create_vector, bld.def(v1), Operand::zero(2), tmp1);
bld.pseudo(aco_opcode::p_unit_test, tmp2);
finish_ra_test(ra_test_policy());
END_TEST