nir: Pass fully qualified type to nir_const_value_negative_equal

Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Suggested-by: Jason Ekstrand <jason@jlekstrand.net>
Reviewed-by: Matt Turner <mattst88@gmail.com>
This commit is contained in:
Ian Romanick
2019-06-13 12:59:29 -07:00
parent 0ac5ff9ecb
commit ec96c289ea
3 changed files with 181 additions and 229 deletions

View File

@@ -1032,8 +1032,7 @@ nir_ssa_alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
bool nir_const_value_negative_equal(const nir_const_value *c1, bool nir_const_value_negative_equal(const nir_const_value *c1,
const nir_const_value *c2, const nir_const_value *c2,
unsigned components, unsigned components,
nir_alu_type base_type, nir_alu_type full_type);
unsigned bits);
bool nir_alu_srcs_equal(const nir_alu_instr *alu1, const nir_alu_instr *alu2, bool nir_alu_srcs_equal(const nir_alu_instr *alu1, const nir_alu_instr *alu2,
unsigned src1, unsigned src2); unsigned src1, unsigned src2);

View File

@@ -305,95 +305,73 @@ bool
nir_const_value_negative_equal(const nir_const_value *c1, nir_const_value_negative_equal(const nir_const_value *c1,
const nir_const_value *c2, const nir_const_value *c2,
unsigned components, unsigned components,
nir_alu_type base_type, nir_alu_type full_type)
unsigned bits)
{ {
assert(base_type == nir_alu_type_get_base_type(base_type)); assert(nir_alu_type_get_base_type(full_type) != nir_type_invalid);
assert(base_type != nir_type_invalid); assert(nir_alu_type_get_type_size(full_type) != 0);
/* This can occur for 1-bit Boolean values. */ switch (full_type) {
if (bits == 1) case nir_type_float16:
return false; for (unsigned i = 0; i < components; i++) {
if (_mesa_half_to_float(c1[i].u16) !=
switch (base_type) { -_mesa_half_to_float(c2[i].u16)) {
case nir_type_float: return false;
switch (bits) {
case 16:
for (unsigned i = 0; i < components; i++) {
if (_mesa_half_to_float(c1[i].u16) !=
-_mesa_half_to_float(c2[i].u16)) {
return false;
}
} }
return true;
case 32:
for (unsigned i = 0; i < components; i++) {
if (c1[i].f32 != -c2[i].f32)
return false;
}
return true;
case 64:
for (unsigned i = 0; i < components; i++) {
if (c1[i].f64 != -c2[i].f64)
return false;
}
return true;
default:
unreachable("unknown bit size");
} }
break; return true;
case nir_type_int: case nir_type_float32:
case nir_type_uint: for (unsigned i = 0; i < components; i++) {
switch (bits) { if (c1[i].f32 != -c2[i].f32)
case 8: return false;
for (unsigned i = 0; i < components; i++) {
if (c1[i].i8 != -c2[i].i8)
return false;
}
return true;
case 16:
for (unsigned i = 0; i < components; i++) {
if (c1[i].i16 != -c2[i].i16)
return false;
}
return true;
break;
case 32:
for (unsigned i = 0; i < components; i++) {
if (c1[i].i32 != -c2[i].i32)
return false;
}
return true;
case 64:
for (unsigned i = 0; i < components; i++) {
if (c1[i].i64 != -c2[i].i64)
return false;
}
return true;
default:
unreachable("unknown bit size");
} }
break; return true;
case nir_type_bool: case nir_type_float64:
return false; for (unsigned i = 0; i < components; i++) {
if (c1[i].f64 != -c2[i].f64)
return false;
}
return true;
case nir_type_int8:
case nir_type_uint8:
for (unsigned i = 0; i < components; i++) {
if (c1[i].i8 != -c2[i].i8)
return false;
}
return true;
case nir_type_int16:
case nir_type_uint16:
for (unsigned i = 0; i < components; i++) {
if (c1[i].i16 != -c2[i].i16)
return false;
}
return true;
case nir_type_int32:
case nir_type_uint32:
for (unsigned i = 0; i < components; i++) {
if (c1[i].i32 != -c2[i].i32)
return false;
}
return true;
case nir_type_int64:
case nir_type_uint64:
for (unsigned i = 0; i < components; i++) {
if (c1[i].i64 != -c2[i].i64)
return false;
}
return true;
default: default:
break; break;
@@ -449,7 +427,7 @@ nir_alu_srcs_negative_equal(const nir_alu_instr *alu1,
return nir_const_value_negative_equal(const1, return nir_const_value_negative_equal(const1,
const2, const2,
nir_ssa_alu_instr_src_components(alu1, src1), nir_ssa_alu_instr_src_components(alu1, src1),
nir_op_infos[alu1->op].input_types[src1], nir_op_infos[alu1->op].input_types[src1] |
nir_src_bit_size(alu1->src[src1].src)); nir_src_bit_size(alu1->src[src1].src));
} }

View File

@@ -26,10 +26,10 @@
#include "util/half_float.h" #include "util/half_float.h"
static void count_sequence(nir_const_value c[NIR_MAX_VEC_COMPONENTS], static void count_sequence(nir_const_value c[NIR_MAX_VEC_COMPONENTS],
nir_alu_type base_type, unsigned bits, int first); nir_alu_type full_type, int first);
static void negate(nir_const_value dst[NIR_MAX_VEC_COMPONENTS], static void negate(nir_const_value dst[NIR_MAX_VEC_COMPONENTS],
const nir_const_value src[NIR_MAX_VEC_COMPONENTS], const nir_const_value src[NIR_MAX_VEC_COMPONENTS],
nir_alu_type base_type, unsigned bits, unsigned components); nir_alu_type full_type, unsigned components);
class const_value_negative_equal_test : public ::testing::Test { class const_value_negative_equal_test : public ::testing::Test {
protected: protected:
@@ -68,89 +68,89 @@ TEST_F(const_value_negative_equal_test, float32_zero)
{ {
/* Verify that 0.0 negative-equals 0.0. */ /* Verify that 0.0 negative-equals 0.0. */
EXPECT_TRUE(nir_const_value_negative_equal(c1, c1, NIR_MAX_VEC_COMPONENTS, EXPECT_TRUE(nir_const_value_negative_equal(c1, c1, NIR_MAX_VEC_COMPONENTS,
nir_type_float, 32)); nir_type_float32));
} }
TEST_F(const_value_negative_equal_test, float64_zero) TEST_F(const_value_negative_equal_test, float64_zero)
{ {
/* Verify that 0.0 negative-equals 0.0. */ /* Verify that 0.0 negative-equals 0.0. */
EXPECT_TRUE(nir_const_value_negative_equal(c1, c1, NIR_MAX_VEC_COMPONENTS, EXPECT_TRUE(nir_const_value_negative_equal(c1, c1, NIR_MAX_VEC_COMPONENTS,
nir_type_float, 64)); nir_type_float64));
} }
/* Compare an object with non-zero values to itself. This should always be /* Compare an object with non-zero values to itself. This should always be
* false. * false.
*/ */
#define compare_with_self(base_type, bits) \ #define compare_with_self(full_type) \
TEST_F(const_value_negative_equal_test, base_type ## bits ## _self) \ TEST_F(const_value_negative_equal_test, full_type ## _self) \
{ \ { \
count_sequence(c1, base_type, bits, 1); \ count_sequence(c1, full_type, 1); \
EXPECT_FALSE(nir_const_value_negative_equal(c1, c1, \ EXPECT_FALSE(nir_const_value_negative_equal(c1, c1, \
NIR_MAX_VEC_COMPONENTS, \ NIR_MAX_VEC_COMPONENTS, \
base_type, bits)); \ full_type)); \
} }
compare_with_self(nir_type_float, 16) compare_with_self(nir_type_float16)
compare_with_self(nir_type_float, 32) compare_with_self(nir_type_float32)
compare_with_self(nir_type_float, 64) compare_with_self(nir_type_float64)
compare_with_self(nir_type_int, 8) compare_with_self(nir_type_int8)
compare_with_self(nir_type_uint, 8) compare_with_self(nir_type_uint8)
compare_with_self(nir_type_int, 16) compare_with_self(nir_type_int16)
compare_with_self(nir_type_uint, 16) compare_with_self(nir_type_uint16)
compare_with_self(nir_type_int, 32) compare_with_self(nir_type_int32)
compare_with_self(nir_type_uint, 32) compare_with_self(nir_type_uint32)
compare_with_self(nir_type_int, 64) compare_with_self(nir_type_int64)
compare_with_self(nir_type_uint, 64) compare_with_self(nir_type_uint64)
/* Compare an object with the negation of itself. This should always be true. /* Compare an object with the negation of itself. This should always be true.
*/ */
#define compare_with_negation(base_type, bits) \ #define compare_with_negation(full_type) \
TEST_F(const_value_negative_equal_test, base_type ## bits ## _trivially_true) \ TEST_F(const_value_negative_equal_test, full_type ## _trivially_true) \
{ \ { \
count_sequence(c1, base_type, bits, 1); \ count_sequence(c1, full_type, 1); \
negate(c2, c1, base_type, bits, NIR_MAX_VEC_COMPONENTS); \ negate(c2, c1, full_type, NIR_MAX_VEC_COMPONENTS); \
EXPECT_TRUE(nir_const_value_negative_equal(c1, c2, \ EXPECT_TRUE(nir_const_value_negative_equal(c1, c2, \
NIR_MAX_VEC_COMPONENTS, \ NIR_MAX_VEC_COMPONENTS, \
base_type, bits)); \ full_type)); \
} }
compare_with_negation(nir_type_float, 16) compare_with_negation(nir_type_float16)
compare_with_negation(nir_type_float, 32) compare_with_negation(nir_type_float32)
compare_with_negation(nir_type_float, 64) compare_with_negation(nir_type_float64)
compare_with_negation(nir_type_int, 8) compare_with_negation(nir_type_int8)
compare_with_negation(nir_type_uint, 8) compare_with_negation(nir_type_uint8)
compare_with_negation(nir_type_int, 16) compare_with_negation(nir_type_int16)
compare_with_negation(nir_type_uint, 16) compare_with_negation(nir_type_uint16)
compare_with_negation(nir_type_int, 32) compare_with_negation(nir_type_int32)
compare_with_negation(nir_type_uint, 32) compare_with_negation(nir_type_uint32)
compare_with_negation(nir_type_int, 64) compare_with_negation(nir_type_int64)
compare_with_negation(nir_type_uint, 64) compare_with_negation(nir_type_uint64)
/* Compare fewer than the maximum possible components. All of the components /* Compare fewer than the maximum possible components. All of the components
* that are compared a negative-equal, but the extra components are not. * that are compared a negative-equal, but the extra components are not.
*/ */
#define compare_fewer_components(base_type, bits) \ #define compare_fewer_components(full_type) \
TEST_F(const_value_negative_equal_test, base_type ## bits ## _fewer_components) \ TEST_F(const_value_negative_equal_test, full_type ## _fewer_components) \
{ \ { \
count_sequence(c1, base_type, bits, 1); \ count_sequence(c1, full_type, 1); \
negate(c2, c1, base_type, bits, 3); \ negate(c2, c1, full_type, 3); \
EXPECT_TRUE(nir_const_value_negative_equal(c1, c2, 3, base_type, bits)); \ EXPECT_TRUE(nir_const_value_negative_equal(c1, c2, 3, full_type)); \
EXPECT_FALSE(nir_const_value_negative_equal(c1, c2, \ EXPECT_FALSE(nir_const_value_negative_equal(c1, c2, \
NIR_MAX_VEC_COMPONENTS, \ NIR_MAX_VEC_COMPONENTS, \
base_type, bits)); \ full_type)); \
} }
compare_fewer_components(nir_type_float, 16) compare_fewer_components(nir_type_float16)
compare_fewer_components(nir_type_float, 32) compare_fewer_components(nir_type_float32)
compare_fewer_components(nir_type_float, 64) compare_fewer_components(nir_type_float64)
compare_fewer_components(nir_type_int, 8) compare_fewer_components(nir_type_int8)
compare_fewer_components(nir_type_uint, 8) compare_fewer_components(nir_type_uint8)
compare_fewer_components(nir_type_int, 16) compare_fewer_components(nir_type_int16)
compare_fewer_components(nir_type_uint, 16) compare_fewer_components(nir_type_uint16)
compare_fewer_components(nir_type_int, 32) compare_fewer_components(nir_type_int32)
compare_fewer_components(nir_type_uint, 32) compare_fewer_components(nir_type_uint32)
compare_fewer_components(nir_type_int, 64) compare_fewer_components(nir_type_int64)
compare_fewer_components(nir_type_uint, 64) compare_fewer_components(nir_type_uint64)
TEST_F(alu_srcs_negative_equal_test, trivial_float) TEST_F(alu_srcs_negative_equal_test, trivial_float)
{ {
@@ -221,65 +221,53 @@ TEST_F(alu_srcs_negative_equal_test, trivial_negation_int)
} }
static void static void
count_sequence(nir_const_value c[NIR_MAX_VEC_COMPONENTS], nir_alu_type base_type, unsigned bits, int first) count_sequence(nir_const_value c[NIR_MAX_VEC_COMPONENTS],
nir_alu_type full_type, int first)
{ {
switch (base_type) { switch (full_type) {
case nir_type_float: case nir_type_float16:
switch (bits) { for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
case 16: c[i].u16 = _mesa_float_to_half(float(i + first));
for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
c[i].u16 = _mesa_float_to_half(float(i + first));
break;
case 32:
for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
c[i].f32 = float(i + first);
break;
case 64:
for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
c[i].f64 = double(i + first);
break;
default:
unreachable("unknown bit size");
}
break; break;
case nir_type_int: case nir_type_float32:
case nir_type_uint: for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
switch (bits) { c[i].f32 = float(i + first);
case 8:
for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
c[i].i8 = i + first;
break; break;
case 16: case nir_type_float64:
for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
c[i].i16 = i + first; c[i].f64 = double(i + first);
break; break;
case 32: case nir_type_int8:
for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) case nir_type_uint8:
c[i].i32 = i + first; for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
c[i].i8 = i + first;
break; break;
case 64: case nir_type_int16:
for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) case nir_type_uint16:
c[i].i64 = i + first; for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
c[i].i16 = i + first;
break; break;
default: case nir_type_int32:
unreachable("unknown bit size"); case nir_type_uint32:
} for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
c[i].i32 = i + first;
break;
case nir_type_int64:
case nir_type_uint64:
for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
c[i].i64 = i + first;
break; break;
@@ -292,65 +280,52 @@ count_sequence(nir_const_value c[NIR_MAX_VEC_COMPONENTS], nir_alu_type base_type
static void static void
negate(nir_const_value dst[NIR_MAX_VEC_COMPONENTS], negate(nir_const_value dst[NIR_MAX_VEC_COMPONENTS],
const nir_const_value src[NIR_MAX_VEC_COMPONENTS], const nir_const_value src[NIR_MAX_VEC_COMPONENTS],
nir_alu_type base_type, unsigned bits, unsigned components) nir_alu_type full_type, unsigned components)
{ {
switch (base_type) { switch (full_type) {
case nir_type_float: case nir_type_float16:
switch (bits) { for (unsigned i = 0; i < components; i++)
case 16: dst[i].u16 = _mesa_float_to_half(-_mesa_half_to_float(src[i].u16));
for (unsigned i = 0; i < components; i++)
dst[i].u16 = _mesa_float_to_half(-_mesa_half_to_float(src[i].u16));
break;
case 32:
for (unsigned i = 0; i < components; i++)
dst[i].f32 = -src[i].f32;
break;
case 64:
for (unsigned i = 0; i < components; i++)
dst[i].f64 = -src[i].f64;
break;
default:
unreachable("unknown bit size");
}
break; break;
case nir_type_int: case nir_type_float32:
case nir_type_uint: for (unsigned i = 0; i < components; i++)
switch (bits) { dst[i].f32 = -src[i].f32;
case 8:
for (unsigned i = 0; i < components; i++)
dst[i].i8 = -src[i].i8;
break; break;
case 16: case nir_type_float64:
for (unsigned i = 0; i < components; i++) for (unsigned i = 0; i < components; i++)
dst[i].i16 = -src[i].i16; dst[i].f64 = -src[i].f64;
break; break;
case 32: case nir_type_int8:
for (unsigned i = 0; i < components; i++) case nir_type_uint8:
dst[i].i32 = -src[i].i32; for (unsigned i = 0; i < components; i++)
dst[i].i8 = -src[i].i8;
break; break;
case 64: case nir_type_int16:
for (unsigned i = 0; i < components; i++) case nir_type_uint16:
dst[i].i64 = -src[i].i64; for (unsigned i = 0; i < components; i++)
dst[i].i16 = -src[i].i16;
break; break;
default: case nir_type_int32:
unreachable("unknown bit size"); case nir_type_uint32:
} for (unsigned i = 0; i < components; i++)
dst[i].i32 = -src[i].i32;
break;
case nir_type_int64:
case nir_type_uint64:
for (unsigned i = 0; i < components; i++)
dst[i].i64 = -src[i].i64;
break; break;