nir/search: Search for all combinations of commutative ops
Consider the following search expression and NIR sequence: ('iadd', ('imul', a, b), b) ssa_2 = imul ssa_0, ssa_1 ssa_3 = iadd ssa_2, ssa_0 The current algorithm is greedy and, the moment the imul finds a match, it commits those variable names and returns success. In the above example, it maps a -> ssa_0 and b -> ssa_1. When we then try to match the iadd, it sees that ssa_0 is not b and fails to match. The iadd match will attempt to flip itself and try again (which won't work) but it cannot ask the imul to try a flipped match. This commit instead counts the number of commutative ops in each expression and assigns an index to each. It then does a loop and loops over the full combinatorial matrix of commutative operations. In order to keep things sane, we limit it to at most 4 commutative operations (16 combinations). There is only one optimization in opt_algebraic that goes over this limit and it's the bitfieldReverse detection for some UE4 demo. Shader-db results on Kaby Lake: total instructions in shared programs: 15310125 -> 15302469 (-0.05%) instructions in affected programs: 1797123 -> 1789467 (-0.43%) helped: 6751 HURT: 2264 total cycles in shared programs: 357346617 -> 357202526 (-0.04%) cycles in affected programs: 15931005 -> 15786914 (-0.90%) helped: 6024 HURT: 3436 total loops in shared programs: 4360 -> 4360 (0.00%) loops in affected programs: 0 -> 0 helped: 0 HURT: 0 total spills in shared programs: 23675 -> 23666 (-0.04%) spills in affected programs: 235 -> 226 (-3.83%) helped: 5 HURT: 1 total fills in shared programs: 32040 -> 32032 (-0.02%) fills in affected programs: 190 -> 182 (-4.21%) helped: 6 HURT: 2 LOST: 18 GAINED: 5 Reviewed-by: Thomas Helland <thomashelland90@gmail.com>
This commit is contained in:

committed by
Jason Ekstrand

parent
48e48b8560
commit
50f3535d1f
@@ -114,6 +114,7 @@ static const ${val.c_type} ${val.name} = {
|
||||
${val.cond if val.cond else 'NULL'},
|
||||
% elif isinstance(val, Expression):
|
||||
${'true' if val.inexact else 'false'},
|
||||
${val.comm_expr_idx}, ${val.comm_exprs},
|
||||
${val.c_opcode()},
|
||||
{ ${', '.join(src.c_ptr for src in val.sources)} },
|
||||
${val.cond if val.cond else 'NULL'},
|
||||
@@ -307,6 +308,25 @@ class Expression(Value):
|
||||
'Expression cannot use an unsized conversion opcode with ' \
|
||||
'an explicit size; that\'s silly.'
|
||||
|
||||
self.__index_comm_exprs(0)
|
||||
|
||||
def __index_comm_exprs(self, base_idx):
|
||||
"""Recursively count and index commutative expressions
|
||||
"""
|
||||
self.comm_exprs = 0
|
||||
if self.opcode not in conv_opcode_types and \
|
||||
"commutative" in opcodes[self.opcode].algebraic_properties:
|
||||
self.comm_expr_idx = base_idx
|
||||
self.comm_exprs += 1
|
||||
else:
|
||||
self.comm_expr_idx = -1
|
||||
|
||||
for s in self.sources:
|
||||
if isinstance(s, Expression):
|
||||
s.__index_comm_exprs(base_idx + self.comm_exprs)
|
||||
self.comm_exprs += s.comm_exprs
|
||||
|
||||
return self.comm_exprs
|
||||
|
||||
def c_opcode(self):
|
||||
if self.opcode in conv_opcode_types:
|
||||
|
@@ -30,9 +30,12 @@
|
||||
#include "nir_builder.h"
|
||||
#include "util/half_float.h"
|
||||
|
||||
#define NIR_SEARCH_MAX_COMM_OPS 4
|
||||
|
||||
struct match_state {
|
||||
bool inexact_match;
|
||||
bool has_exact_alu;
|
||||
uint8_t comm_op_direction;
|
||||
unsigned variables_seen;
|
||||
nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
|
||||
};
|
||||
@@ -349,41 +352,25 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
|
||||
}
|
||||
}
|
||||
|
||||
/* Stash off the current variables_seen bitmask. This way we can
|
||||
* restore it prior to matching in the commutative case below.
|
||||
/* If this is a commutative expression and it's one of the first few, look
|
||||
* up its direction for the current search operation. We'll use that value
|
||||
* to possibly flip the sources for the match.
|
||||
*/
|
||||
unsigned variables_seen_stash = state->variables_seen;
|
||||
unsigned comm_op_flip =
|
||||
(expr->comm_expr_idx >= 0 &&
|
||||
expr->comm_expr_idx < NIR_SEARCH_MAX_COMM_OPS) ?
|
||||
((state->comm_op_direction >> expr->comm_expr_idx) & 1) : 0;
|
||||
|
||||
bool matched = true;
|
||||
for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
|
||||
if (!match_value(expr->srcs[i], instr, i, num_components,
|
||||
swizzle, state)) {
|
||||
if (!match_value(expr->srcs[i], instr, i ^ comm_op_flip,
|
||||
num_components, swizzle, state)) {
|
||||
matched = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (matched)
|
||||
return true;
|
||||
|
||||
if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
|
||||
assert(nir_op_infos[instr->op].num_inputs == 2);
|
||||
|
||||
/* Restore the variables_seen bitmask. If we don't do this, then we
|
||||
* could end up with an erroneous failure due to variables found in the
|
||||
* first match attempt above not matching those in the second.
|
||||
*/
|
||||
state->variables_seen = variables_seen_stash;
|
||||
|
||||
if (!match_value(expr->srcs[0], instr, 1, num_components,
|
||||
swizzle, state))
|
||||
return false;
|
||||
|
||||
return match_value(expr->srcs[1], instr, 0, num_components,
|
||||
swizzle, state);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return matched;
|
||||
}
|
||||
|
||||
static unsigned
|
||||
@@ -513,10 +500,26 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
|
||||
struct match_state state;
|
||||
state.inexact_match = false;
|
||||
state.has_exact_alu = false;
|
||||
|
||||
unsigned comm_expr_combinations =
|
||||
1 << MIN2(search->comm_exprs, NIR_SEARCH_MAX_COMM_OPS);
|
||||
|
||||
bool found = false;
|
||||
for (unsigned comb = 0; comb < comm_expr_combinations; comb++) {
|
||||
/* The bitfield of directions is just the current iteration. Hooray for
|
||||
* binary.
|
||||
*/
|
||||
state.comm_op_direction = comb;
|
||||
state.variables_seen = 0;
|
||||
|
||||
if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
|
||||
swizzle, &state))
|
||||
if (match_expression(search, instr,
|
||||
instr->dest.dest.ssa.num_components,
|
||||
swizzle, &state)) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found)
|
||||
return NULL;
|
||||
|
||||
build->cursor = nir_before_instr(&instr->instr);
|
||||
|
@@ -132,6 +132,18 @@ typedef struct {
|
||||
*/
|
||||
bool inexact;
|
||||
|
||||
/* Commutative expression index. This is assigned by opt_algebraic.py when
|
||||
* search structures are constructed and is a unique (to this structure)
|
||||
* index within the commutative operation bitfield used for searching for
|
||||
* all combinations of expressions containing commutative operations.
|
||||
*/
|
||||
int8_t comm_expr_idx;
|
||||
|
||||
/* Number of commutative expressions in this expression including this one
|
||||
* (if it is commutative).
|
||||
*/
|
||||
uint8_t comm_exprs;
|
||||
|
||||
/* One of nir_op or nir_search_op */
|
||||
uint16_t opcode;
|
||||
const nir_search_value *srcs[4];
|
||||
|
Reference in New Issue
Block a user