nv50/ir/nir: implement nir_alu_instr handling

v2: user bitfield_insert instead of bfi
    rework switch helper macros
    remove some lowering code (LoweringHelper is now used for this)
v3: add pack_half_2x16_split
    add unpack_half_2x16_split_x/y
v5: replace first argument with nullptr in loadImm calls
    prefer getSSA over getScratch
v8: fix setting precise modifier for first instruction inside a block
    add guard in case no instruction gets inserted into an empty block
    don't require C++11 features
v9: use CC_NE for integer compares
    convert to C++ style comments
    fix b2f for doubles
    remove macros around nir ops to make it easier to grep them
    add handling for fpow

Signed-off-by: Karol Herbst <kherbst@redhat.com>
This commit is contained in:
Karol Herbst
2017-12-12 21:05:30 +01:00
parent c69b814728
commit 6513c675ad

View File

@@ -114,9 +114,17 @@ private:
std::vector<DataType> getSTypes(nir_alu_instr *);
DataType getSType(nir_src &, bool isFloat, bool isSigned);
operation getOperation(nir_op);
operation preOperationNeeded(nir_op);
int getSubOp(nir_op);
CondCode getCondCode(nir_op);
bool assignSlots();
bool parseNIR();
bool visit(nir_alu_instr *);
bool visit(nir_block *);
bool visit(nir_cf_node *);
bool visit(nir_function *);
@@ -135,6 +143,7 @@ private:
unsigned int curLoopDepth;
BasicBlock *exit;
Value *zero;
union {
struct {
@@ -146,7 +155,10 @@ private:
Converter::Converter(Program *prog, nir_shader *nir, nv50_ir_prog_info *info)
: ConverterCommon(prog, info),
nir(nir),
curLoopDepth(0) {}
curLoopDepth(0)
{
zero = mkImm((uint32_t)0);
}
BasicBlock *
Converter::convert(nir_block *block)
@@ -275,6 +287,191 @@ Converter::getSType(nir_src &src, bool isFloat, bool isSigned)
return ty;
}
operation
Converter::getOperation(nir_op op)
{
switch (op) {
// basic ops with float and int variants
case nir_op_fabs:
case nir_op_iabs:
return OP_ABS;
case nir_op_fadd:
case nir_op_iadd:
return OP_ADD;
case nir_op_fand:
case nir_op_iand:
return OP_AND;
case nir_op_ifind_msb:
case nir_op_ufind_msb:
return OP_BFIND;
case nir_op_fceil:
return OP_CEIL;
case nir_op_fcos:
return OP_COS;
case nir_op_f2f32:
case nir_op_f2f64:
case nir_op_f2i32:
case nir_op_f2i64:
case nir_op_f2u32:
case nir_op_f2u64:
case nir_op_i2f32:
case nir_op_i2f64:
case nir_op_i2i32:
case nir_op_i2i64:
case nir_op_u2f32:
case nir_op_u2f64:
case nir_op_u2u32:
case nir_op_u2u64:
return OP_CVT;
case nir_op_fddx:
case nir_op_fddx_coarse:
case nir_op_fddx_fine:
return OP_DFDX;
case nir_op_fddy:
case nir_op_fddy_coarse:
case nir_op_fddy_fine:
return OP_DFDY;
case nir_op_fdiv:
case nir_op_idiv:
case nir_op_udiv:
return OP_DIV;
case nir_op_fexp2:
return OP_EX2;
case nir_op_ffloor:
return OP_FLOOR;
case nir_op_ffma:
return OP_FMA;
case nir_op_flog2:
return OP_LG2;
case nir_op_fmax:
case nir_op_imax:
case nir_op_umax:
return OP_MAX;
case nir_op_pack_64_2x32_split:
return OP_MERGE;
case nir_op_fmin:
case nir_op_imin:
case nir_op_umin:
return OP_MIN;
case nir_op_fmod:
case nir_op_imod:
case nir_op_umod:
case nir_op_frem:
case nir_op_irem:
return OP_MOD;
case nir_op_fmul:
case nir_op_imul:
case nir_op_imul_high:
case nir_op_umul_high:
return OP_MUL;
case nir_op_fneg:
case nir_op_ineg:
return OP_NEG;
case nir_op_fnot:
case nir_op_inot:
return OP_NOT;
case nir_op_for:
case nir_op_ior:
return OP_OR;
case nir_op_fpow:
return OP_POW;
case nir_op_frcp:
return OP_RCP;
case nir_op_frsq:
return OP_RSQ;
case nir_op_fsat:
return OP_SAT;
case nir_op_feq32:
case nir_op_ieq32:
case nir_op_fge32:
case nir_op_ige32:
case nir_op_uge32:
case nir_op_flt32:
case nir_op_ilt32:
case nir_op_ult32:
case nir_op_fne32:
case nir_op_ine32:
return OP_SET;
case nir_op_ishl:
return OP_SHL;
case nir_op_ishr:
case nir_op_ushr:
return OP_SHR;
case nir_op_fsin:
return OP_SIN;
case nir_op_fsqrt:
return OP_SQRT;
case nir_op_fsub:
case nir_op_isub:
return OP_SUB;
case nir_op_ftrunc:
return OP_TRUNC;
case nir_op_fxor:
case nir_op_ixor:
return OP_XOR;
default:
ERROR("couldn't get operation for op %s\n", nir_op_infos[op].name);
assert(false);
return OP_NOP;
}
}
operation
Converter::preOperationNeeded(nir_op op)
{
switch (op) {
case nir_op_fcos:
case nir_op_fsin:
return OP_PRESIN;
default:
return OP_NOP;
}
}
int
Converter::getSubOp(nir_op op)
{
switch (op) {
case nir_op_imul_high:
case nir_op_umul_high:
return NV50_IR_SUBOP_MUL_HIGH;
default:
return 0;
}
}
CondCode
Converter::getCondCode(nir_op op)
{
switch (op) {
case nir_op_feq32:
case nir_op_ieq32:
return CC_EQ;
case nir_op_fge32:
case nir_op_ige32:
case nir_op_uge32:
return CC_GE;
case nir_op_flt32:
case nir_op_ilt32:
case nir_op_ult32:
return CC_LT;
case nir_op_fne32:
return CC_NEU;
case nir_op_ine32:
return CC_NE;
default:
ERROR("couldn't get CondCode for op %s\n", nir_op_infos[op].name);
assert(false);
return CC_FL;
}
}
Converter::LValues&
Converter::convert(nir_alu_dest *dest)
{
return convert(&dest->dest);
}
Converter::LValues&
Converter::convert(nir_dest *dest)
{
@@ -1314,6 +1511,8 @@ bool
Converter::visit(nir_instr *insn)
{
switch (insn->type) {
case nir_instr_type_alu:
return visit(nir_instr_as_alu(insn));
case nir_instr_type_intrinsic:
return visit(nir_instr_as_intrinsic(insn));
case nir_instr_type_jump:
@@ -1393,6 +1592,367 @@ Converter::visit(nir_load_const_instr *insn)
return true;
}
#define DEFAULT_CHECKS \
if (insn->dest.dest.ssa.num_components > 1) { \
ERROR("nir_alu_instr only supported with 1 component!\n"); \
return false; \
} \
if (insn->dest.write_mask != 1) { \
ERROR("nir_alu_instr only with write_mask of 1 supported!\n"); \
return false; \
}
bool
Converter::visit(nir_alu_instr *insn)
{
const nir_op op = insn->op;
const nir_op_info &info = nir_op_infos[op];
DataType dType = getDType(insn);
const std::vector<DataType> sTypes = getSTypes(insn);
Instruction *oldPos = this->bb->getExit();
switch (op) {
case nir_op_fabs:
case nir_op_iabs:
case nir_op_fadd:
case nir_op_iadd:
case nir_op_fand:
case nir_op_iand:
case nir_op_fceil:
case nir_op_fcos:
case nir_op_fddx:
case nir_op_fddx_coarse:
case nir_op_fddx_fine:
case nir_op_fddy:
case nir_op_fddy_coarse:
case nir_op_fddy_fine:
case nir_op_fdiv:
case nir_op_idiv:
case nir_op_udiv:
case nir_op_fexp2:
case nir_op_ffloor:
case nir_op_ffma:
case nir_op_flog2:
case nir_op_fmax:
case nir_op_imax:
case nir_op_umax:
case nir_op_fmin:
case nir_op_imin:
case nir_op_umin:
case nir_op_fmod:
case nir_op_imod:
case nir_op_umod:
case nir_op_fmul:
case nir_op_imul:
case nir_op_imul_high:
case nir_op_umul_high:
case nir_op_fneg:
case nir_op_ineg:
case nir_op_fnot:
case nir_op_inot:
case nir_op_for:
case nir_op_ior:
case nir_op_pack_64_2x32_split:
case nir_op_fpow:
case nir_op_frcp:
case nir_op_frem:
case nir_op_irem:
case nir_op_frsq:
case nir_op_fsat:
case nir_op_ishr:
case nir_op_ushr:
case nir_op_fsin:
case nir_op_fsqrt:
case nir_op_fsub:
case nir_op_isub:
case nir_op_ftrunc:
case nir_op_ishl:
case nir_op_fxor:
case nir_op_ixor: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
operation preOp = preOperationNeeded(op);
if (preOp != OP_NOP) {
assert(info.num_inputs < 2);
Value *tmp = getSSA(typeSizeof(dType));
Instruction *i0 = mkOp(preOp, dType, tmp);
Instruction *i1 = mkOp(getOperation(op), dType, newDefs[0]);
if (info.num_inputs) {
i0->setSrc(0, getSrc(&insn->src[0]));
i1->setSrc(0, tmp);
}
i1->subOp = getSubOp(op);
} else {
Instruction *i = mkOp(getOperation(op), dType, newDefs[0]);
for (unsigned s = 0u; s < info.num_inputs; ++s) {
i->setSrc(s, getSrc(&insn->src[s]));
}
i->subOp = getSubOp(op);
}
break;
}
case nir_op_ifind_msb:
case nir_op_ufind_msb: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
dType = sTypes[0];
mkOp1(getOperation(op), dType, newDefs[0], getSrc(&insn->src[0]));
break;
}
case nir_op_fround_even: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
mkCvt(OP_CVT, dType, newDefs[0], dType, getSrc(&insn->src[0]))->rnd = ROUND_NI;
break;
}
// convert instructions
case nir_op_f2f32:
case nir_op_f2i32:
case nir_op_f2u32:
case nir_op_i2f32:
case nir_op_i2i32:
case nir_op_u2f32:
case nir_op_u2u32:
case nir_op_f2f64:
case nir_op_f2i64:
case nir_op_f2u64:
case nir_op_i2f64:
case nir_op_i2i64:
case nir_op_u2f64:
case nir_op_u2u64: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
Instruction *i = mkOp1(getOperation(op), dType, newDefs[0], getSrc(&insn->src[0]));
if (op == nir_op_f2i32 || op == nir_op_f2i64 || op == nir_op_f2u32 || op == nir_op_f2u64)
i->rnd = ROUND_Z;
i->sType = sTypes[0];
break;
}
// compare instructions
case nir_op_feq32:
case nir_op_ieq32:
case nir_op_fge32:
case nir_op_ige32:
case nir_op_uge32:
case nir_op_flt32:
case nir_op_ilt32:
case nir_op_ult32:
case nir_op_fne32:
case nir_op_ine32: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
Instruction *i = mkCmp(getOperation(op),
getCondCode(op),
dType,
newDefs[0],
dType,
getSrc(&insn->src[0]),
getSrc(&insn->src[1]));
if (info.num_inputs == 3)
i->setSrc(2, getSrc(&insn->src[2]));
i->sType = sTypes[0];
break;
}
// those are weird ALU ops and need special handling, because
// 1. they are always componend based
// 2. they basically just merge multiple values into one data type
case nir_op_imov:
case nir_op_fmov:
case nir_op_vec2:
case nir_op_vec3:
case nir_op_vec4: {
LValues &newDefs = convert(&insn->dest);
for (LValues::size_type c = 0u; c < newDefs.size(); ++c) {
mkMov(newDefs[c], getSrc(&insn->src[c]), dType);
}
break;
}
// (un)pack
case nir_op_pack_64_2x32: {
LValues &newDefs = convert(&insn->dest);
Instruction *merge = mkOp(OP_MERGE, dType, newDefs[0]);
merge->setSrc(0, getSrc(&insn->src[0], 0));
merge->setSrc(1, getSrc(&insn->src[0], 1));
break;
}
case nir_op_pack_half_2x16_split: {
LValues &newDefs = convert(&insn->dest);
Value *tmpH = getSSA();
Value *tmpL = getSSA();
mkCvt(OP_CVT, TYPE_F16, tmpL, TYPE_F32, getSrc(&insn->src[0]));
mkCvt(OP_CVT, TYPE_F16, tmpH, TYPE_F32, getSrc(&insn->src[1]));
mkOp3(OP_INSBF, TYPE_U32, newDefs[0], tmpH, mkImm(0x1010), tmpL);
break;
}
case nir_op_unpack_half_2x16_split_x:
case nir_op_unpack_half_2x16_split_y: {
LValues &newDefs = convert(&insn->dest);
Instruction *cvt = mkCvt(OP_CVT, TYPE_F32, newDefs[0], TYPE_F16, getSrc(&insn->src[0]));
if (op == nir_op_unpack_half_2x16_split_y)
cvt->subOp = 1;
break;
}
case nir_op_unpack_64_2x32: {
LValues &newDefs = convert(&insn->dest);
mkOp1(OP_SPLIT, dType, newDefs[0], getSrc(&insn->src[0]))->setDef(1, newDefs[1]);
break;
}
case nir_op_unpack_64_2x32_split_x: {
LValues &newDefs = convert(&insn->dest);
mkOp1(OP_SPLIT, dType, newDefs[0], getSrc(&insn->src[0]))->setDef(1, getSSA());
break;
}
case nir_op_unpack_64_2x32_split_y: {
LValues &newDefs = convert(&insn->dest);
mkOp1(OP_SPLIT, dType, getSSA(), getSrc(&insn->src[0]))->setDef(1, newDefs[0]);
break;
}
// special instructions
case nir_op_fsign:
case nir_op_isign: {
DEFAULT_CHECKS;
DataType iType;
if (::isFloatType(dType))
iType = TYPE_F32;
else
iType = TYPE_S32;
LValues &newDefs = convert(&insn->dest);
LValue *val0 = getScratch();
LValue *val1 = getScratch();
mkCmp(OP_SET, CC_GT, iType, val0, dType, getSrc(&insn->src[0]), zero);
mkCmp(OP_SET, CC_LT, iType, val1, dType, getSrc(&insn->src[0]), zero);
if (dType == TYPE_F64) {
mkOp2(OP_SUB, iType, val0, val0, val1);
mkCvt(OP_CVT, TYPE_F64, newDefs[0], iType, val0);
} else if (dType == TYPE_S64 || dType == TYPE_U64) {
mkOp2(OP_SUB, iType, val0, val1, val0);
mkOp2(OP_SHR, iType, val1, val0, loadImm(NULL, 31));
mkOp2(OP_MERGE, dType, newDefs[0], val0, val1);
} else if (::isFloatType(dType))
mkOp2(OP_SUB, iType, newDefs[0], val0, val1);
else
mkOp2(OP_SUB, iType, newDefs[0], val1, val0);
break;
}
case nir_op_fcsel:
case nir_op_b32csel: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
mkCmp(OP_SLCT, CC_NE, dType, newDefs[0], sTypes[0], getSrc(&insn->src[1]), getSrc(&insn->src[2]), getSrc(&insn->src[0]));
break;
}
case nir_op_ibitfield_extract:
case nir_op_ubitfield_extract: {
DEFAULT_CHECKS;
Value *tmp = getSSA();
LValues &newDefs = convert(&insn->dest);
mkOp3(OP_INSBF, dType, tmp, getSrc(&insn->src[2]), loadImm(NULL, 0x808), getSrc(&insn->src[1]));
mkOp2(OP_EXTBF, dType, newDefs[0], getSrc(&insn->src[0]), tmp);
break;
}
case nir_op_bfm: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
mkOp3(OP_INSBF, dType, newDefs[0], getSrc(&insn->src[0]), loadImm(NULL, 0x808), getSrc(&insn->src[1]));
break;
}
case nir_op_bitfield_insert: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
LValue *temp = getSSA();
mkOp3(OP_INSBF, TYPE_U32, temp, getSrc(&insn->src[3]), mkImm(0x808), getSrc(&insn->src[2]));
mkOp3(OP_INSBF, dType, newDefs[0], getSrc(&insn->src[1]), temp, getSrc(&insn->src[0]));
break;
}
case nir_op_bit_count: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
mkOp2(OP_POPCNT, dType, newDefs[0], getSrc(&insn->src[0]), getSrc(&insn->src[0]));
break;
}
case nir_op_bitfield_reverse: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
mkOp2(OP_EXTBF, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), mkImm(0x2000))->subOp = NV50_IR_SUBOP_EXTBF_REV;
break;
}
case nir_op_find_lsb: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
Value *tmp = getSSA();
mkOp2(OP_EXTBF, TYPE_U32, tmp, getSrc(&insn->src[0]), mkImm(0x2000))->subOp = NV50_IR_SUBOP_EXTBF_REV;
mkOp1(OP_BFIND, TYPE_U32, newDefs[0], tmp)->subOp = NV50_IR_SUBOP_BFIND_SAMT;
break;
}
// boolean conversions
case nir_op_b2f32: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
mkOp2(OP_AND, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), loadImm(NULL, 1.0f));
break;
}
case nir_op_b2f64: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
Value *tmp = getSSA(4);
mkOp2(OP_AND, TYPE_U32, tmp, getSrc(&insn->src[0]), loadImm(NULL, 0x3ff00000));
mkOp2(OP_MERGE, TYPE_U64, newDefs[0], loadImm(NULL, 0), tmp);
break;
}
case nir_op_f2b32:
case nir_op_i2b32: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
Value *src1;
if (typeSizeof(sTypes[0]) == 8) {
src1 = loadImm(getSSA(8), 0.0);
} else {
src1 = zero;
}
CondCode cc = op == nir_op_f2b32 ? CC_NEU : CC_NE;
mkCmp(OP_SET, cc, TYPE_U32, newDefs[0], sTypes[0], getSrc(&insn->src[0]), src1);
break;
}
case nir_op_b2i32: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
mkOp2(OP_AND, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), loadImm(NULL, 1));
break;
}
case nir_op_b2i64: {
DEFAULT_CHECKS;
LValues &newDefs = convert(&insn->dest);
LValue *def = getScratch();
mkOp2(OP_AND, TYPE_U32, def, getSrc(&insn->src[0]), loadImm(NULL, 1));
mkOp2(OP_MERGE, TYPE_S64, newDefs[0], def, loadImm(NULL, 0));
break;
}
default:
ERROR("unknown nir_op %s\n", info.name);
return false;
}
if (!oldPos) {
oldPos = this->bb->getEntry();
oldPos->precise = insn->exact;
}
if (unlikely(!oldPos))
return true;
while (oldPos->next) {
oldPos = oldPos->next;
oldPos->precise = insn->exact;
}
oldPos->saturate = insn->dest.saturate;
return true;
}
#undef DEFAULT_CHECKS
bool
Converter::run()
{