nir: Maintain the algebraic automaton's state as we work.
In order to have nir_opt_algebraic be able to do further algebraic work on the output of a replacement, we need to maintain the automaton's state. Reviewed-by: Eric Anholt <eric@anholt.net>
This commit is contained in:

committed by
Eric Anholt

parent
2da4a58ed9
commit
305d1300f9
@@ -38,6 +38,11 @@ struct match_state {
|
||||
bool has_exact_alu;
|
||||
uint8_t comm_op_direction;
|
||||
unsigned variables_seen;
|
||||
|
||||
/* Used for running the automaton on newly-constructed instructions. */
|
||||
struct util_dynarray *states;
|
||||
const struct per_op_table *pass_op_table;
|
||||
|
||||
nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
|
||||
struct hash_table *range_ht;
|
||||
};
|
||||
@@ -46,6 +51,9 @@ static bool
|
||||
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
|
||||
unsigned num_components, const uint8_t *swizzle,
|
||||
struct match_state *state);
|
||||
static void
|
||||
nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
|
||||
const struct per_op_table *pass_op_table);
|
||||
|
||||
static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
|
||||
|
||||
@@ -490,6 +498,11 @@ construct_value(nir_builder *build,
|
||||
|
||||
nir_builder_instr_insert(build, &alu->instr);
|
||||
|
||||
assert(alu->dest.dest.ssa.index ==
|
||||
util_dynarray_num_elements(state->states, uint16_t));
|
||||
util_dynarray_append(state->states, uint16_t, 0);
|
||||
nir_algebraic_automaton(&alu->instr, state->states, state->pass_op_table);
|
||||
|
||||
nir_alu_src val;
|
||||
val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
|
||||
val.negate = false;
|
||||
@@ -537,6 +550,12 @@ construct_value(nir_builder *build,
|
||||
unreachable("Invalid alu source type");
|
||||
}
|
||||
|
||||
assert(cval->index ==
|
||||
util_dynarray_num_elements(state->states, uint16_t));
|
||||
util_dynarray_append(state->states, uint16_t, 0);
|
||||
nir_algebraic_automaton(cval->parent_instr, state->states,
|
||||
state->pass_op_table);
|
||||
|
||||
nir_alu_src val;
|
||||
val.src = nir_src_for_ssa(cval);
|
||||
val.negate = false;
|
||||
@@ -624,6 +643,8 @@ UNUSED static void dump_value(const nir_search_value *val)
|
||||
nir_ssa_def *
|
||||
nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
|
||||
struct hash_table *range_ht,
|
||||
struct util_dynarray *states,
|
||||
const struct per_op_table *pass_op_table,
|
||||
const nir_search_expression *search,
|
||||
const nir_search_value *replace)
|
||||
{
|
||||
@@ -638,6 +659,7 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
|
||||
state.inexact_match = false;
|
||||
state.has_exact_alu = false;
|
||||
state.range_ht = range_ht;
|
||||
state.pass_op_table = pass_op_table;
|
||||
|
||||
STATIC_ASSERT(sizeof(state.comm_op_direction) * 8 >= NIR_SEARCH_MAX_COMM_OPS);
|
||||
|
||||
@@ -672,6 +694,8 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
|
||||
|
||||
build->cursor = nir_before_instr(&instr->instr);
|
||||
|
||||
state.states = states;
|
||||
|
||||
nir_alu_src val = construct_value(build, replace,
|
||||
instr->dest.dest.ssa.num_components,
|
||||
instr->dest.dest.ssa.bit_size,
|
||||
@@ -682,6 +706,11 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
|
||||
*/
|
||||
nir_ssa_def *ssa_val =
|
||||
nir_mov_alu(build, val, instr->dest.dest.ssa.num_components);
|
||||
if (ssa_val->index == util_dynarray_num_elements(states, uint16_t)) {
|
||||
util_dynarray_append(states, uint16_t, 0);
|
||||
nir_algebraic_automaton(ssa_val->parent_instr, states, pass_op_table);
|
||||
}
|
||||
|
||||
nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
|
||||
|
||||
/* We know this one has no more uses because we just rewrote them all,
|
||||
@@ -694,42 +723,43 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
|
||||
}
|
||||
|
||||
static void
|
||||
nir_algebraic_automaton(nir_block *block, uint16_t *states,
|
||||
nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
|
||||
const struct per_op_table *pass_op_table)
|
||||
{
|
||||
nir_foreach_instr(instr, block) {
|
||||
switch (instr->type) {
|
||||
case nir_instr_type_alu: {
|
||||
nir_alu_instr *alu = nir_instr_as_alu(instr);
|
||||
nir_op op = alu->op;
|
||||
uint16_t search_op = nir_search_op_for_nir_op(op);
|
||||
const struct per_op_table *tbl = &pass_op_table[search_op];
|
||||
if (tbl->num_filtered_states == 0)
|
||||
continue;
|
||||
switch (instr->type) {
|
||||
case nir_instr_type_alu: {
|
||||
nir_alu_instr *alu = nir_instr_as_alu(instr);
|
||||
nir_op op = alu->op;
|
||||
uint16_t search_op = nir_search_op_for_nir_op(op);
|
||||
const struct per_op_table *tbl = &pass_op_table[search_op];
|
||||
if (tbl->num_filtered_states == 0)
|
||||
return;
|
||||
|
||||
/* Calculate the index into the transition table. Note the index
|
||||
* calculated must match the iteration order of Python's
|
||||
* itertools.product(), which was used to emit the transition
|
||||
* table.
|
||||
*/
|
||||
uint16_t index = 0;
|
||||
for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
|
||||
index *= tbl->num_filtered_states;
|
||||
index += tbl->filter[states[alu->src[i].src.ssa->index]];
|
||||
}
|
||||
states[alu->dest.dest.ssa.index] = tbl->table[index];
|
||||
break;
|
||||
/* Calculate the index into the transition table. Note the index
|
||||
* calculated must match the iteration order of Python's
|
||||
* itertools.product(), which was used to emit the transition
|
||||
* table.
|
||||
*/
|
||||
uint16_t index = 0;
|
||||
for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
|
||||
index *= tbl->num_filtered_states;
|
||||
index += tbl->filter[*util_dynarray_element(states, uint16_t,
|
||||
alu->src[i].src.ssa->index)];
|
||||
}
|
||||
*util_dynarray_element(states, uint16_t, alu->dest.dest.ssa.index) =
|
||||
tbl->table[index];
|
||||
break;
|
||||
}
|
||||
|
||||
case nir_instr_type_load_const: {
|
||||
nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
|
||||
states[load_const->def.index] = CONST_STATE;
|
||||
break;
|
||||
}
|
||||
case nir_instr_type_load_const: {
|
||||
nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
|
||||
*util_dynarray_element(states, uint16_t, load_const->def.index) =
|
||||
CONST_STATE;
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -739,7 +769,8 @@ nir_algebraic_block(nir_builder *build, nir_block *block,
|
||||
const bool *condition_flags,
|
||||
const struct transform **transforms,
|
||||
const uint16_t *transform_counts,
|
||||
const uint16_t *states)
|
||||
struct util_dynarray *states,
|
||||
const struct per_op_table *pass_op_table)
|
||||
{
|
||||
bool progress = false;
|
||||
const unsigned execution_mode = build->shader->info.float_controls_execution_mode;
|
||||
@@ -757,12 +788,13 @@ nir_algebraic_block(nir_builder *build, nir_block *block,
|
||||
nir_is_float_control_signed_zero_inf_nan_preserve(execution_mode, bit_size) ||
|
||||
nir_is_denorm_flush_to_zero(execution_mode, bit_size);
|
||||
|
||||
int xform_idx = states[alu->dest.dest.ssa.index];
|
||||
int xform_idx = *util_dynarray_element(states, uint16_t,
|
||||
alu->dest.dest.ssa.index);
|
||||
for (uint16_t i = 0; i < transform_counts[xform_idx]; i++) {
|
||||
const struct transform *xform = &transforms[xform_idx][i];
|
||||
if (condition_flags[xform->condition_offset] &&
|
||||
!(xform->search->inexact && ignore_inexact) &&
|
||||
nir_replace_instr(build, alu, range_ht,
|
||||
nir_replace_instr(build, alu, range_ht, states, pass_op_table,
|
||||
xform->search, xform->replace)) {
|
||||
_mesa_hash_table_clear(range_ht, NULL);
|
||||
progress = true;
|
||||
@@ -790,22 +822,27 @@ nir_algebraic_impl(nir_function_impl *impl,
|
||||
* state 0 is the default state, which means we don't have to visit
|
||||
* anything other than constants and ALU instructions.
|
||||
*/
|
||||
uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states));
|
||||
struct util_dynarray states = {0};
|
||||
if (!util_dynarray_resize(&states, uint16_t, impl->ssa_alloc))
|
||||
return false;
|
||||
memset(states.data, 0, states.size);
|
||||
|
||||
struct hash_table *range_ht = _mesa_pointer_hash_table_create(NULL);
|
||||
|
||||
nir_foreach_block(block, impl) {
|
||||
nir_algebraic_automaton(block, states, pass_op_table);
|
||||
nir_foreach_instr(instr, block) {
|
||||
nir_algebraic_automaton(instr, &states, pass_op_table);
|
||||
}
|
||||
}
|
||||
|
||||
nir_foreach_block_reverse(block, impl) {
|
||||
progress |= nir_algebraic_block(&build, block, range_ht, condition_flags,
|
||||
transforms, transform_counts,
|
||||
states);
|
||||
&states, pass_op_table);
|
||||
}
|
||||
|
||||
ralloc_free(range_ht);
|
||||
free(states);
|
||||
util_dynarray_fini(&states);
|
||||
|
||||
if (progress) {
|
||||
nir_metadata_preserve(impl, nir_metadata_block_index |
|
||||
|
Reference in New Issue
Block a user