nir/loop_analyze: Use try_eval_const_alu and induction variable basis info

This dramatically simplifies will_break_on_first_iteration, and, much
more importantly, makes it significantly more flexible. It is now
possible to handle loops with more complex exit condition and other
kinds of increment operations.

Reviewed-by: Timothy Arceri <tarceri@itsqueeze.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/3445>
This commit is contained in:
Ian Romanick
2023-02-13 17:33:29 -08:00
committed by Marge Bot
parent 99a7a6648d
commit bc170e895f

View File

@@ -905,40 +905,20 @@ get_iteration(nir_op cond_op, nir_const_value initial, nir_const_value step,
}
static bool
will_break_on_first_iteration(nir_const_value step,
nir_alu_type induction_base_type,
unsigned trip_offset,
nir_op cond_op, unsigned bit_size,
nir_const_value initial,
nir_const_value limit,
bool limit_rhs, bool invert_cond,
unsigned execution_mode)
will_break_on_first_iteration(nir_alu_instr *cond_alu, nir_ssa_def *basis,
nir_ssa_def *limit_basis,
nir_const_value initial, nir_const_value limit,
bool invert_cond, unsigned execution_mode)
{
if (trip_offset == 1) {
nir_op add_op;
switch (induction_base_type) {
case nir_type_float:
add_op = nir_op_fadd;
break;
case nir_type_int:
case nir_type_uint:
add_op = nir_op_iadd;
break;
default:
unreachable("Unhandled induction variable base type!");
}
initial = eval_const_binop(add_op, bit_size, initial, step,
execution_mode);
}
nir_const_value *src[2];
src[limit_rhs ? 0 : 1] = &initial;
src[limit_rhs ? 1 : 0] = &limit;
/* Evaluate the loop exit condition */
nir_const_value result;
nir_eval_const_opcode(cond_op, &result, 1, bit_size, src, execution_mode);
const nir_ssa_def *originals[2] = { basis, limit_basis };
const nir_const_value *replacements[2] = { &initial, &limit };
ASSERTED bool success = try_eval_const_alu(&result, cond_alu, originals,
replacements, 2, execution_mode);
assert(success);
return invert_cond ? !result.b : result.b;
}
@@ -993,7 +973,8 @@ test_iterations(int32_t iter_int, nir_const_value step,
}
static int
calculate_iterations(nir_const_value initial, nir_const_value step,
calculate_iterations(nir_ssa_def *basis, nir_ssa_def *limit_basis,
nir_const_value initial, nir_const_value step,
nir_const_value limit, nir_alu_instr *alu,
nir_ssa_scalar cond, nir_op alu_op, bool limit_rhs,
bool invert_cond, unsigned execution_mode)
@@ -1043,10 +1024,8 @@ calculate_iterations(nir_const_value initial, nir_const_value step,
* however if the loop condition is false on the first iteration
* get_iteration's assumption is broken. Handle such loops first.
*/
if (will_break_on_first_iteration(step, induction_base_type, trip_offset,
alu_op, bit_size, initial,
limit, limit_rhs, invert_cond,
execution_mode)) {
if (will_break_on_first_iteration(cond_alu, basis, limit_basis, initial,
limit, invert_cond, execution_mode)) {
return 0;
}
@@ -1329,7 +1308,8 @@ find_trip_count(loop_info_state *state, unsigned execution_mode)
nir_const_value initial_val = nir_ssa_scalar_as_const_value(initial_s);
nir_const_value step_val = nir_ssa_scalar_as_const_value(alu_s);
int iterations = calculate_iterations(initial_val, step_val, limit_val,
int iterations = calculate_iterations(lv->basis, limit.def,
initial_val, step_val, limit_val,
nir_instr_as_alu(lv->update_src->src.parent_instr),
cond,
alu_op, limit_rhs,