nir: support more loop unrolling for logical operators

Here we support finding the loop trip count when the termination
condition is a logical OR (ior), in addition to the existing iand case.

Acked-by: Pavel Ondračka <pavel.ondracka@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/28998>
This commit is contained in:
Timothy Arceri
2024-04-23 16:22:19 +10:00
committed by Marge Bot
parent 987cf4b47d
commit e25da8d8d7

View File

@@ -1156,25 +1156,25 @@ get_induction_and_limit_vars(nir_scalar cond,
} }
static bool static bool
try_find_trip_count_vars_in_iand(nir_scalar *cond, try_find_trip_count_vars_in_logical_op(nir_scalar *cond,
nir_scalar *ind, nir_scalar *ind,
nir_scalar *limit, nir_scalar *limit,
bool *limit_rhs, bool *limit_rhs,
loop_info_state *state) loop_info_state *state)
{ {
const nir_op alu_op = nir_scalar_alu_op(*cond); const nir_op alu_op = nir_scalar_alu_op(*cond);
assert(alu_op == nir_op_ieq || alu_op == nir_op_inot); bool exit_loop_on_false = alu_op == nir_op_ieq || alu_op == nir_op_inot;
nir_scalar logical_op = exit_loop_on_false ?
nir_scalar iand = nir_scalar_chase_alu_src(*cond, 0); nir_scalar_chase_alu_src(*cond, 0) : *cond;
if (alu_op == nir_op_ieq) { if (alu_op == nir_op_ieq) {
nir_scalar zero = nir_scalar_chase_alu_src(*cond, 1); nir_scalar zero = nir_scalar_chase_alu_src(*cond, 1);
if (!nir_scalar_is_alu(iand) || !nir_scalar_is_const(zero)) { if (!nir_scalar_is_alu(logical_op) || !nir_scalar_is_const(zero)) {
/* Maybe we had it the wrong way, flip things around */ /* Maybe we had it the wrong way, flip things around */
nir_scalar tmp = zero; nir_scalar tmp = zero;
zero = iand; zero = logical_op;
iand = tmp; logical_op = tmp;
/* If we still didn't find what we need then return */ /* If we still didn't find what we need then return */
if (!nir_scalar_is_const(zero)) if (!nir_scalar_is_const(zero))
@@ -1186,10 +1186,11 @@ try_find_trip_count_vars_in_iand(nir_scalar *cond,
return false; return false;
} }
if (!nir_scalar_is_alu(iand)) if (!nir_scalar_is_alu(logical_op))
return false; return false;
if (nir_scalar_alu_op(iand) != nir_op_iand) if ((exit_loop_on_false && (nir_scalar_alu_op(logical_op) != nir_op_iand)) ||
(!exit_loop_on_false && (nir_scalar_alu_op(logical_op) != nir_op_ior)))
return false; return false;
/* Check if iand src is a terminator condition and try get induction var /* Check if iand src is a terminator condition and try get induction var
@@ -1197,7 +1198,7 @@ try_find_trip_count_vars_in_iand(nir_scalar *cond,
*/ */
bool found_induction_var = false; bool found_induction_var = false;
for (unsigned i = 0; i < 2; i++) { for (unsigned i = 0; i < 2; i++) {
nir_scalar src = nir_scalar_chase_alu_src(iand, i); nir_scalar src = nir_scalar_chase_alu_src(logical_op, i);
if (nir_is_terminator_condition_with_two_inputs(src) && if (nir_is_terminator_condition_with_two_inputs(src) &&
get_induction_and_limit_vars(src, ind, limit, limit_rhs, state)) { get_induction_and_limit_vars(src, ind, limit, limit_rhs, state)) {
*cond = src; *cond = src;
@@ -1248,16 +1249,19 @@ find_trip_count(loop_info_state *state, unsigned execution_mode,
bool limit_rhs; bool limit_rhs;
nir_scalar basic_ind = { NULL, 0 }; nir_scalar basic_ind = { NULL, 0 };
nir_scalar limit; nir_scalar limit;
if ((alu_op == nir_op_inot || alu_op == nir_op_ieq) &&
try_find_trip_count_vars_in_iand(&cond, &basic_ind, &limit, if ((alu_op == nir_op_inot || alu_op == nir_op_ieq || alu_op == nir_op_ior) &&
try_find_trip_count_vars_in_logical_op(&cond, &basic_ind, &limit,
&limit_rhs, state)) { &limit_rhs, state)) {
/* The loop is exiting on (x && y) == 0 so we need to get the /* The loop is exiting on (x && y) == 0 so we need to get the
* inverse of x or y (i.e. which ever contained the induction var) in * inverse of x or y (i.e. which ever contained the induction var) in
* order to compute the trip count. * order to compute the trip count.
*/ */
alu_op = nir_scalar_alu_op(cond); if (alu_op == nir_op_inot || alu_op == nir_op_ieq)
invert_cond = !invert_cond; invert_cond = !invert_cond;
alu_op = nir_scalar_alu_op(cond);
trip_count_known = false; trip_count_known = false;
terminator->conditional_instr = cond.def->parent_instr; terminator->conditional_instr = cond.def->parent_instr;
terminator->exact_trip_count_unknown = true; terminator->exact_trip_count_unknown = true;