nir: Add support for 2src_commutative ops that have 3 sources
v2: Instead of handling 3 sources as a special case, generalize with loops to N sources. Suggested by Jason.

v3: Further generalize by only checking that the number of sources is >= 2. Suggested by Jason.

Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
This commit is contained in:
@@ -796,12 +796,12 @@ class TreeAutomaton(object):
|
||||
self.opcodes = self.IndexMap()
|
||||
|
||||
def get_item(opcode, children, pattern=None):
|
||||
commutative = len(children) == 2 \
|
||||
commutative = len(children) >= 2 \
|
||||
and "2src_commutative" in opcodes[opcode].algebraic_properties
|
||||
item = self.items.setdefault((opcode, children),
|
||||
self.Item(opcode, children))
|
||||
if commutative:
|
||||
self.items[opcode, (children[1], children[0])] = item
|
||||
self.items[opcode, (children[1], children[0]) + children[2:]] = item
|
||||
if pattern is not None:
|
||||
item.patterns.append(pattern)
|
||||
return item
|
||||
|
@@ -57,7 +57,8 @@ hash_alu(uint32_t hash, const nir_alu_instr *instr)
|
||||
/* We explicitly don't hash instr->dest.dest.exact */
|
||||
|
||||
if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_2SRC_COMMUTATIVE) {
|
||||
assert(nir_op_infos[instr->op].num_inputs == 2);
|
||||
assert(nir_op_infos[instr->op].num_inputs >= 2);
|
||||
|
||||
uint32_t hash0 = hash_alu_src(hash, &instr->src[0],
|
||||
nir_ssa_alu_instr_src_components(instr, 0));
|
||||
uint32_t hash1 = hash_alu_src(hash, &instr->src[1],
|
||||
@@ -69,6 +70,11 @@ hash_alu(uint32_t hash, const nir_alu_instr *instr)
|
||||
* collision. Either addition or multiplication will also work.
|
||||
*/
|
||||
hash = hash0 * hash1;
|
||||
|
||||
for (unsigned i = 2; i < nir_op_infos[instr->op].num_inputs; i++) {
|
||||
hash = hash_alu_src(hash, &instr->src[i],
|
||||
nir_ssa_alu_instr_src_components(instr, i));
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
|
||||
hash = hash_alu_src(hash, &instr->src[i],
|
||||
@@ -529,11 +535,16 @@ nir_instrs_equal(const nir_instr *instr1, const nir_instr *instr2)
|
||||
/* We explicitly don't hash instr->dest.dest.exact */
|
||||
|
||||
if (nir_op_infos[alu1->op].algebraic_properties & NIR_OP_IS_2SRC_COMMUTATIVE) {
|
||||
assert(nir_op_infos[alu1->op].num_inputs == 2);
|
||||
return (nir_alu_srcs_equal(alu1, alu2, 0, 0) &&
|
||||
nir_alu_srcs_equal(alu1, alu2, 1, 1)) ||
|
||||
(nir_alu_srcs_equal(alu1, alu2, 0, 1) &&
|
||||
nir_alu_srcs_equal(alu1, alu2, 1, 0));
|
||||
if ((!nir_alu_srcs_equal(alu1, alu2, 0, 0) ||
|
||||
!nir_alu_srcs_equal(alu1, alu2, 1, 1)) &&
|
||||
(!nir_alu_srcs_equal(alu1, alu2, 0, 1) ||
|
||||
!nir_alu_srcs_equal(alu1, alu2, 1, 0)))
|
||||
return false;
|
||||
|
||||
for (unsigned i = 2; i < nir_op_infos[alu1->op].num_inputs; i++) {
|
||||
if (!nir_alu_srcs_equal(alu1, alu2, i, i))
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) {
|
||||
if (!nir_alu_srcs_equal(alu1, alu2, i, i))
|
||||
|
@@ -408,7 +408,11 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
|
||||
|
||||
bool matched = true;
|
||||
for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
|
||||
if (!match_value(expr->srcs[i], instr, i ^ comm_op_flip,
|
||||
/* 2src_commutative instructions that have 3 sources are only commutative
|
||||
* in the first two sources. Source 2 is always source 2.
|
||||
*/
|
||||
if (!match_value(expr->srcs[i], instr,
|
||||
i < 2 ? i ^ comm_op_flip : i,
|
||||
num_components, swizzle, state)) {
|
||||
matched = false;
|
||||
break;
|
||||
|
Reference in New Issue
Block a user