nir/algebraic: add ignore_exact() wrapper
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com> Reviewed-by: Timur Kristóf <timur.kristof@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13436>
This commit is contained in:
@@ -98,13 +98,33 @@ class VarSet(object):
|
|||||||
def lock(self):
|
def lock(self):
|
||||||
self.immutable = True
|
self.immutable = True
|
||||||
|
|
||||||
|
class SearchExpression(object):
|
||||||
|
def __init__(self, expr):
|
||||||
|
self.opcode = expr[0]
|
||||||
|
self.sources = expr[1:]
|
||||||
|
self.ignore_exact = False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create(val):
|
||||||
|
if isinstance(val, tuple):
|
||||||
|
return SearchExpression(val)
|
||||||
|
else:
|
||||||
|
assert(isinstance(val, SearchExpression))
|
||||||
|
return val
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
l = [self.opcode, *self.sources]
|
||||||
|
if self.ignore_exact:
|
||||||
|
l.append('ignore_exact')
|
||||||
|
return repr((*l,))
|
||||||
|
|
||||||
class Value(object):
|
class Value(object):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(val, name_base, varset, algebraic_pass):
|
def create(val, name_base, varset, algebraic_pass):
|
||||||
if isinstance(val, bytes):
|
if isinstance(val, bytes):
|
||||||
val = val.decode('utf-8')
|
val = val.decode('utf-8')
|
||||||
|
|
||||||
if isinstance(val, tuple):
|
if isinstance(val, tuple) or isinstance(val, SearchExpression):
|
||||||
return Expression(val, name_base, varset, algebraic_pass)
|
return Expression(val, name_base, varset, algebraic_pass)
|
||||||
elif isinstance(val, Expression):
|
elif isinstance(val, Expression):
|
||||||
return val
|
return val
|
||||||
@@ -185,7 +205,9 @@ class Value(object):
|
|||||||
${val.cond_index},
|
${val.cond_index},
|
||||||
${val.swizzle()},
|
${val.swizzle()},
|
||||||
% elif isinstance(val, Expression):
|
% elif isinstance(val, Expression):
|
||||||
${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
|
${'true' if val.inexact else 'false'},
|
||||||
|
${'true' if val.exact else 'false'},
|
||||||
|
${'true' if val.ignore_exact else 'false'},
|
||||||
${val.c_opcode()},
|
${val.c_opcode()},
|
||||||
${val.comm_expr_idx}, ${val.comm_exprs},
|
${val.comm_expr_idx}, ${val.comm_exprs},
|
||||||
{ ${', '.join(src.array_index for src in val.sources)} },
|
{ ${', '.join(src.array_index for src in val.sources)} },
|
||||||
@@ -339,15 +361,17 @@ _opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bit
|
|||||||
class Expression(Value):
|
class Expression(Value):
|
||||||
def __init__(self, expr, name_base, varset, algebraic_pass):
|
def __init__(self, expr, name_base, varset, algebraic_pass):
|
||||||
Value.__init__(self, expr, name_base, "expression")
|
Value.__init__(self, expr, name_base, "expression")
|
||||||
assert isinstance(expr, tuple)
|
|
||||||
|
|
||||||
m = _opcode_re.match(expr[0])
|
expr = SearchExpression.create(expr)
|
||||||
|
|
||||||
|
m = _opcode_re.match(expr.opcode)
|
||||||
assert m and m.group('opcode') is not None
|
assert m and m.group('opcode') is not None
|
||||||
|
|
||||||
self.opcode = m.group('opcode')
|
self.opcode = m.group('opcode')
|
||||||
self._bit_size = int(m.group('bits')) if m.group('bits') else None
|
self._bit_size = int(m.group('bits')) if m.group('bits') else None
|
||||||
self.inexact = m.group('inexact') is not None
|
self.inexact = m.group('inexact') is not None
|
||||||
self.exact = m.group('exact') is not None
|
self.exact = m.group('exact') is not None
|
||||||
|
self.ignore_exact = expr.ignore_exact
|
||||||
self.cond = m.group('cond')
|
self.cond = m.group('cond')
|
||||||
|
|
||||||
assert not self.inexact or not self.exact, \
|
assert not self.inexact or not self.exact, \
|
||||||
@@ -372,7 +396,7 @@ class Expression(Value):
|
|||||||
self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)
|
self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)
|
||||||
|
|
||||||
self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass)
|
self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass)
|
||||||
for (i, src) in enumerate(expr[1:]) ]
|
for (i, src) in enumerate(expr.sources) ]
|
||||||
|
|
||||||
# nir_search_expression::srcs is hard-coded to 4
|
# nir_search_expression::srcs is hard-coded to 4
|
||||||
assert len(self.sources) <= 4
|
assert len(self.sources) <= 4
|
||||||
@@ -1235,3 +1259,9 @@ class AlgebraicPass(object):
|
|||||||
variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]),
|
variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]),
|
||||||
get_c_opcode=get_c_opcode,
|
get_c_opcode=get_c_opcode,
|
||||||
itertools=itertools)
|
itertools=itertools)
|
||||||
|
|
||||||
|
# The replacement expression isn't necessarily exact if the search expression is exact.
|
||||||
|
def ignore_exact(*expr):
|
||||||
|
expr = SearchExpression.create(expr)
|
||||||
|
expr.ignore_exact = True
|
||||||
|
return expr
|
||||||
|
@@ -41,6 +41,8 @@ e = 'e'
|
|||||||
signed_zero_inf_nan_preserve_16 = 'nir_is_float_control_signed_zero_inf_nan_preserve(info->float_controls_execution_mode, 16)'
|
signed_zero_inf_nan_preserve_16 = 'nir_is_float_control_signed_zero_inf_nan_preserve(info->float_controls_execution_mode, 16)'
|
||||||
signed_zero_inf_nan_preserve_32 = 'nir_is_float_control_signed_zero_inf_nan_preserve(info->float_controls_execution_mode, 32)'
|
signed_zero_inf_nan_preserve_32 = 'nir_is_float_control_signed_zero_inf_nan_preserve(info->float_controls_execution_mode, 32)'
|
||||||
|
|
||||||
|
ignore_exact = nir_algebraic.ignore_exact
|
||||||
|
|
||||||
# Written in the form (<search>, <replace>) where <search> is an expression
|
# Written in the form (<search>, <replace>) where <search> is an expression
|
||||||
# and <replace> is either an expression or a value. An expression is
|
# and <replace> is either an expression or a value. An expression is
|
||||||
# defined as a tuple of the form ([~]<op>, <src0>, <src1>, <src2>, <src3>)
|
# defined as a tuple of the form ([~]<op>, <src0>, <src1>, <src2>, <src3>)
|
||||||
|
@@ -408,7 +408,7 @@ match_expression(const nir_algebraic_table *table, const nir_search_expression *
|
|||||||
return false;
|
return false;
|
||||||
|
|
||||||
state->inexact_match = expr->inexact || state->inexact_match;
|
state->inexact_match = expr->inexact || state->inexact_match;
|
||||||
state->has_exact_alu = instr->exact || state->has_exact_alu;
|
state->has_exact_alu = (instr->exact && !expr->ignore_exact) || state->has_exact_alu;
|
||||||
if (state->inexact_match && state->has_exact_alu)
|
if (state->inexact_match && state->has_exact_alu)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
@@ -142,8 +142,11 @@ typedef struct {
|
|||||||
/** In a replacement, requests that the instruction be marked exact. */
|
/** In a replacement, requests that the instruction be marked exact. */
|
||||||
bool exact : 1;
|
bool exact : 1;
|
||||||
|
|
||||||
|
/** Don't make the replacement exact if the search expression is exact. */
|
||||||
|
bool ignore_exact : 1;
|
||||||
|
|
||||||
/* One of nir_op or nir_search_op */
|
/* One of nir_op or nir_search_op */
|
||||||
uint16_t opcode : 14;
|
uint16_t opcode : 13;
|
||||||
|
|
||||||
/* Commutative expression index. This is assigned by opt_algebraic.py when
|
/* Commutative expression index. This is assigned by opt_algebraic.py when
|
||||||
* search structures are constructed and is a unique (to this structure)
|
* search structures are constructed and is a unique (to this structure)
|
||||||
|
Reference in New Issue
Block a user