diff --git a/src/compiler/nir/nir_loop_analyze.c b/src/compiler/nir/nir_loop_analyze.c index 995bf902f63..b2630286fd3 100644 --- a/src/compiler/nir/nir_loop_analyze.c +++ b/src/compiler/nir/nir_loop_analyze.c @@ -454,7 +454,10 @@ compute_induction_information(loop_info_state *state) nir_alu_instr *alu = nir_instr_as_alu(src_var->def->parent_instr); /* Check for unsupported alu operations */ - if (alu->op != nir_op_iadd && alu->op != nir_op_fadd) + if (alu->op != nir_op_iadd && alu->op != nir_op_fadd && + alu->op != nir_op_imul && alu->op != nir_op_fmul && + alu->op != nir_op_ishl && alu->op != nir_op_ishr && + alu->op != nir_op_ushr) break; if (nir_op_infos[alu->op].num_inputs == 2) { @@ -995,9 +998,6 @@ calculate_iterations(nir_ssa_def *basis, nir_ssa_def *limit_basis, induction_base_type); } - /* Only variable with these update ops were marked as induction. */ - assert(alu->op == nir_op_iadd || alu->op == nir_op_fadd); - /* do-while loops can increment the starting value before the condition is * checked. e.g. * @@ -1015,8 +1015,6 @@ calculate_iterations(nir_ssa_def *basis, nir_ssa_def *limit_basis, trip_offset = 1; } - assert(nir_src_bit_size(alu->src[0].src) == - nir_src_bit_size(alu->src[1].src)); unsigned bit_size = nir_src_bit_size(alu->src[0].src); /* get_iteration works under assumption that iterator will be @@ -1029,8 +1027,25 @@ calculate_iterations(nir_ssa_def *basis, nir_ssa_def *limit_basis, return 0; } - int iter_int = get_iteration(alu_op, initial, step, limit, bit_size, - execution_mode); + int iter_int; + switch (alu->op) { + case nir_op_iadd: + case nir_op_fadd: + assert(nir_src_bit_size(alu->src[0].src) == + nir_src_bit_size(alu->src[1].src)); + + iter_int = get_iteration(alu_op, initial, step, limit, bit_size, + execution_mode); + break; + case nir_op_imul: + case nir_op_fmul: + case nir_op_ishl: + case nir_op_ishr: + case nir_op_ushr: + return -1; + default: + unreachable("Invalid induction variable increment operation."); + } /* If iter_int is negative the loop is ill-formed or is the conditional is * unsigned with a huge iteration count so don't bother going any further. diff --git a/src/compiler/nir/tests/loop_analyze_tests.cpp b/src/compiler/nir/tests/loop_analyze_tests.cpp index 9b39a6653f9..7cd3fdd5227 100644 --- a/src/compiler/nir/tests/loop_analyze_tests.cpp +++ b/src/compiler/nir/tests/loop_analyze_tests.cpp @@ -1001,6 +1001,40 @@ INOT_COMPARE(ilt_rev) EXPECT_FALSE(loop->info->exact_trip_count_known); \ } +#define KNOWN_COUNT_TEST_INVERT(_init_value, _incr_value, _cond_value, cond, incr, count) \ + TEST_F(nir_loop_analyze_test, incr ## _ ## cond ## _known_count_invert_ ## count) \ + { \ + nir_loop *loop = \ + loop_builder_invert(&b, {.init_value = _init_value, \ + .incr_value = _incr_value, \ + .cond_value = _cond_value, \ + .cond_instr = nir_ ## cond, \ + .incr_instr = nir_ ## incr}); \ + \ + nir_validate_shader(b.shader, "input"); \ + \ + nir_loop_analyze_impl(b.impl, nir_var_all, false); \ + \ + ASSERT_NE((void *)0, loop->info); \ + EXPECT_NE((void *)0, loop->info->limiting_terminator); \ + EXPECT_EQ(count, loop->info->max_trip_count); \ + EXPECT_TRUE(loop->info->exact_trip_count_known); \ + \ + EXPECT_EQ(2, loop->info->num_induction_vars); \ + ASSERT_NE((void *)0, loop->info->induction_vars); \ + \ + const nir_loop_induction_variable *const ivars = \ + loop->info->induction_vars; \ + \ + for (unsigned i = 0; i < loop->info->num_induction_vars; i++) { \ + EXPECT_NE((void *)0, ivars[i].def); \ + ASSERT_NE((void *)0, ivars[i].init_src); \ + EXPECT_TRUE(nir_src_is_const(*ivars[i].init_src)); \ + ASSERT_NE((void *)0, ivars[i].update_src); \ + EXPECT_TRUE(nir_src_is_const(ivars[i].update_src->src)); \ + } \ + } + #define UNKNOWN_COUNT_TEST_INVERT(_init_value, _incr_value, _cond_value, cond, incr) \ TEST_F(nir_loop_analyze_test, incr ## _ ## cond ## _unknown_count_invert) \ { \ @@ -1191,7 +1225,7 @@ UNKNOWN_COUNT_TEST(0x80000000, 0x00000002, 0x00000001, ult, ushr) * i >>= 1; * } */ -UNKNOWN_COUNT_TEST(0x80000000, 0x00000002, 0x00000001, ult_rev, ushr) +KNOWN_COUNT_TEST(0x80000000, 0x00000002, 0x00000001, ult_rev, ushr, 0) /* uint i = 0x80000000; * while (true) { @@ -1201,7 +1235,7 @@ UNKNOWN_COUNT_TEST(0x80000000, 0x00000002, 0x00000001, ult_rev, ushr) * i >>= 1; * } */ -UNKNOWN_COUNT_TEST(0x80000000, 0x80000000, 0x00000001, uge, ushr) +KNOWN_COUNT_TEST(0x80000000, 0x80000000, 0x00000001, uge, ushr, 0) /* uint i = 0x80000000; * while (true) { @@ -1221,7 +1255,7 @@ UNKNOWN_COUNT_TEST(0x80000000, 0x00008000, 0x00000001, uge_rev, ushr) * break; * } */ -UNKNOWN_COUNT_TEST_INVERT(0x80000000, 0x00000001, 0x80000000, ine, ushr) +KNOWN_COUNT_TEST_INVERT(0x80000000, 0x00000001, 0x80000000, ine, ushr, 0) /* uint i = 0x80000000; * while (true) { @@ -1241,7 +1275,7 @@ UNKNOWN_COUNT_TEST_INVERT(0x80000000, 0x00000001, 0x00000000, ieq, ushr) * break; * } */ -UNKNOWN_COUNT_TEST_INVERT(0x80000000, 0x00000001, 0x80000000, ult, ushr) +KNOWN_COUNT_TEST_INVERT(0x80000000, 0x00000001, 0x80000000, ult, ushr, 0) /* uint i = 0xAAAAAAAA; * while (true) { @@ -1251,7 +1285,7 @@ UNKNOWN_COUNT_TEST_INVERT(0x80000000, 0x00000001, 0x80000000, ult, ushr) * break; * } */ -UNKNOWN_COUNT_TEST_INVERT(0xAAAAAAAA, 0x00000001, 0x08000000, ult_rev, ushr) +KNOWN_COUNT_TEST_INVERT(0xAAAAAAAA, 0x00000001, 0x08000000, ult_rev, ushr, 0) /* uint i = 0x80000000; * while (true) { @@ -1261,7 +1295,7 @@ UNKNOWN_COUNT_TEST_INVERT(0xAAAAAAAA, 0x00000001, 0x08000000, ult_rev, ushr) * break; * } */ -UNKNOWN_COUNT_TEST_INVERT(0x80000000, 0x00000001, 0x00000000, uge, ushr) +KNOWN_COUNT_TEST_INVERT(0x80000000, 0x00000001, 0x00000000, uge, ushr, 0) /* uint i = 0x80000000; * while (true) { @@ -1401,7 +1435,7 @@ INFINITE_LOOP_UNKNOWN_COUNT_TEST_INVERT(0x76543210, 0x00000007, 0xffffffff, ige_ * i >>= 1; * } */ -UNKNOWN_COUNT_TEST(0x7fffffff, 0x00000000, 0x00000001, ine, ishr) +KNOWN_COUNT_TEST(0x7fffffff, 0x00000000, 0x00000001, ine, ishr, 0) /* int i = 0x40000000; * while (true) { @@ -1461,7 +1495,7 @@ UNKNOWN_COUNT_TEST(0x12345678, 0x00000001, 0x00000004, ige_rev, ishr) * break; * } */ -UNKNOWN_COUNT_TEST_INVERT(0x7fffffff, 0x00000001, 0x00000000, ine, ishr) +KNOWN_COUNT_TEST_INVERT(0x7fffffff, 0x00000001, 0x00000000, ine, ishr, 0) /* int i = 0x7fffffff; * while (true) { @@ -1721,7 +1755,7 @@ UNKNOWN_COUNT_TEST_INVERT(0x00000001, 0x00000008, 0x01000000, ieq, ishl) * break; * } */ -UNKNOWN_COUNT_TEST_INVERT(0x7fffffff, 0x00000001, 0x00000001, ilt, ishl) +KNOWN_COUNT_TEST_INVERT(0x7fffffff, 0x00000001, 0x00000001, ilt, ishl, 0) /* int i = 0x7fff; * while (true) {