diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py index 57887947912..c9daccdc59c 100644 --- a/src/compiler/nir/nir_algebraic.py +++ b/src/compiler/nir/nir_algebraic.py @@ -109,7 +109,7 @@ class Value(object): elif isinstance(val, Expression): return val elif isinstance(val, str): - return Variable(val, name_base, varset) + return Variable(val, name_base, varset, algebraic_pass) elif isinstance(val, (bool, float, int)): return Constant(val, name_base) @@ -182,7 +182,7 @@ class Value(object): ${val.index}, /* ${val.var_name} */ ${'true' if val.is_constant else 'false'}, ${val.type() or 'nir_type_invalid' }, - ${val.cond if val.cond else 'NULL'}, + ${val.cond_index}, ${val.swizzle()}, % elif isinstance(val, Expression): ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'}, @@ -269,7 +269,7 @@ _var_name_re = re.compile(r"(?P#)?(?P\w+)" r"$") class Variable(Value): - def __init__(self, val, name, varset): + def __init__(self, val, name, varset, algebraic_pass): Value.__init__(self, val, name, "variable") m = _var_name_re.match(val) @@ -286,7 +286,7 @@ class Variable(Value): assert self.var_name != 'False' self.is_constant = m.group('const') is not None - self.cond = m.group('cond') + self.cond_index = get_cond_index(algebraic_pass.variable_cond, m.group('cond')) self.required_type = m.group('type') self._bit_size = int(m.group('bits')) if m.group('bits') else None self.swiz = m.group('swiz') @@ -1064,6 +1064,14 @@ static const nir_search_expression_cond ${pass_name}_expression_cond[] = { }; % endif +% if variable_cond: +static const nir_search_variable_cond ${pass_name}_variable_cond[] = { +% for cond in variable_cond: + ${cond[0]}, +% endfor +}; +% endif + % for state_id, state_xforms in enumerate(automaton.state_patterns): % if state_xforms: # avoid emitting a 0-length array for MSVC static const struct transform ${pass_name}_state${state_id}_xforms[] = { @@ -1124,6 +1132,7 @@ static const nir_algebraic_table ${pass_name}_table = { .pass_op_table = ${pass_name}_pass_op_table, .values = ${pass_name}_values, .expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" }, + .variable_cond = ${ pass_name + "_variable_cond" if variable_cond else "NULL" }, }; bool @@ -1159,6 +1168,7 @@ class AlgebraicPass(object): self.opcode_xforms = defaultdict(lambda : []) self.pass_name = pass_name self.expression_cond = {} + self.variable_cond = {} error = False @@ -1222,5 +1232,6 @@ class AlgebraicPass(object): condition_list=condition_list, automaton=self.automaton, expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]), + variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]), get_c_opcode=get_c_opcode, itertools=itertools) diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index cbe7ca22a7c..4e20de1397e 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -315,8 +315,8 @@ match_value(const nir_algebraic_table *table, instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const) return false; - if (var->cond && !var->cond(state->range_ht, instr, - src, num_components, new_swizzle)) + if (var->cond_index != -1 && !table->variable_cond[var->cond_index](state->range_ht, instr, + src, num_components, new_swizzle)) return false; if (var->type != nir_type_invalid && diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h index c0bb6129bb6..421ee5f4662 100644 --- a/src/compiler/nir/nir_search.h +++ b/src/compiler/nir/nir_search.h @@ -88,15 +88,14 @@ typedef struct { */ nir_alu_type type; - /** Optional condition fxn ptr + /** Optional table->variable_cond[] fxn ptr index * * This is only allowed in search expressions, and allows additional * constraints to be placed on the match. Typically used for 'is_constant' * variables to require, for example, power-of-two in order for the search * to match. */ - bool (*cond)(struct hash_table *range_ht, const nir_alu_instr *instr, - unsigned src, unsigned num_components, const uint8_t *swizzle); + int16_t cond_index; /** Swizzle (for replace only) */ uint8_t swizzle[NIR_MAX_VEC_COMPONENTS]; @@ -190,6 +189,10 @@ typedef union { } nir_search_value_union; typedef bool (*nir_search_expression_cond)(nir_alu_instr *instr); +typedef bool (*nir_search_variable_cond)(struct hash_table *range_ht, + const nir_alu_instr *instr, + unsigned src, unsigned num_components, + const uint8_t *swizzle); /* Generated data table for an algebraic optimization pass. */ typedef struct { @@ -203,6 +206,12 @@ typedef struct { * nir_search_expression->cond. */ const nir_search_expression_cond *expression_cond; + + /** + * Array of condition functions for variables, referenced by + * nir_search_variable->cond. + */ + const nir_search_variable_cond *variable_cond; } nir_algebraic_table; /* Note: these must match the start states created in