nir/constant_expressions: Refactor helper functions
Apart from avoiding some unneeded size cases, this shouldn't have any actual functional impact. Reviewed-by: Dylan Baker <dylan@pnwbakers.com> Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
This commit is contained in:
@@ -1,16 +1,18 @@
|
||||
|
||||
import re
|
||||
|
||||
type_split_re = re.compile(r'(?P<type>[a-z]+)(?P<bits>\d+)')
|
||||
|
||||
def type_has_size(type_):
|
||||
return type_[-1:].isdigit()
|
||||
|
||||
def type_size(type_):
|
||||
assert type_has_size(type_)
|
||||
return int(type_split_re.match(type_).group('bits'))
|
||||
|
||||
def type_sizes(type_):
|
||||
if type_.endswith("8"):
|
||||
return [8]
|
||||
elif type_.endswith("16"):
|
||||
return [16]
|
||||
elif type_.endswith("32"):
|
||||
return [32]
|
||||
elif type_.endswith("64"):
|
||||
return [64]
|
||||
if type_has_size(type_):
|
||||
return [type_size(type_)]
|
||||
else:
|
||||
return [32, 64]
|
||||
|
||||
@@ -19,23 +21,23 @@ def type_add_size(type_, size):
|
||||
return type_
|
||||
return type_ + str(size)
|
||||
|
||||
def op_bit_sizes(op):
|
||||
sizes = set([8, 16, 32, 64])
|
||||
if not type_has_size(op.output_type):
|
||||
sizes = sizes.intersection(set(type_sizes(op.output_type)))
|
||||
for input_type in op.input_types:
|
||||
if not type_has_size(input_type):
|
||||
sizes = sizes.intersection(set(type_sizes(input_type)))
|
||||
return sorted(list(sizes))
|
||||
|
||||
def get_const_field(type_):
|
||||
if type_ == "int32":
|
||||
return "i32"
|
||||
if type_ == "uint32":
|
||||
return "u32"
|
||||
if type_ == "int64":
|
||||
return "i64"
|
||||
if type_ == "uint64":
|
||||
return "u64"
|
||||
if type_ == "bool32":
|
||||
return "u32"
|
||||
if type_ == "float32":
|
||||
return "f32"
|
||||
if type_ == "float64":
|
||||
return "f64"
|
||||
raise Exception(str(type_))
|
||||
assert(0)
|
||||
else:
|
||||
m = type_split_re.match(type_)
|
||||
if not m:
|
||||
raise Exception(str(type_))
|
||||
return m.group('type')[0] + m.group('bits')
|
||||
|
||||
template = """\
|
||||
/*
|
||||
@@ -247,7 +249,7 @@ typedef float float32_t;
|
||||
typedef double float64_t;
|
||||
typedef bool bool32_t;
|
||||
% for type in ["float", "int", "uint"]:
|
||||
% for width in [32, 64]:
|
||||
% for width in type_sizes(type):
|
||||
struct ${type}${width}_vec {
|
||||
${type}${width}_t x;
|
||||
${type}${width}_t y;
|
||||
@@ -272,7 +274,7 @@ evaluate_${name}(MAYBE_UNUSED unsigned num_components, unsigned bit_size,
|
||||
nir_const_value _dst_val = { {0, } };
|
||||
|
||||
switch (bit_size) {
|
||||
% for bit_size in [32, 64]:
|
||||
% for bit_size in op_bit_sizes(op):
|
||||
case ${bit_size}: {
|
||||
<%
|
||||
output_type = type_add_size(op.output_type, bit_size)
|
||||
@@ -406,4 +408,5 @@ from mako.template import Template
|
||||
print Template(template).render(opcodes=opcodes, type_sizes=type_sizes,
|
||||
type_has_size=type_has_size,
|
||||
type_add_size=type_add_size,
|
||||
op_bit_sizes=op_bit_sizes,
|
||||
get_const_field=get_const_field)
|
||||
|
Reference in New Issue
Block a user