nir: Add support for 1-bit data types
This commit adds support for 1-bit Booleans and integers. Booleans obviously take a value of true or false. Because we have to define the semantics of 1-bit signed and unsigned integers, we define uint1_t to take values of 0 and 1 and int1_t to take values of 0 and -1. 1-bit arithmetic is then well-defined in the usual way, just with fewer bits. The definition of int1_t and uint1_t doesn't usually matter but we do need something for purposes of constant folding. Reviewed-by: Eric Anholt <eric@anholt.net> Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl> Tested-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
This commit is contained in:

committed by
Jason Ekstrand

parent
2fe8708ffd
commit
3191a82372
@@ -638,6 +638,7 @@ const_value_int(int64_t i, unsigned bit_size)
|
||||
{
|
||||
nir_const_value v;
|
||||
switch (bit_size) {
|
||||
case 1: v.b[0] = i & 1; break;
|
||||
case 8: v.i8[0] = i; break;
|
||||
case 16: v.i16[0] = i; break;
|
||||
case 32: v.i32[0] = i; break;
|
||||
@@ -1206,6 +1207,8 @@ nir_src_comp_as_int(nir_src src, unsigned comp)
|
||||
|
||||
assert(comp < load->def.num_components);
|
||||
switch (load->def.bit_size) {
|
||||
/* int1_t uses 0/-1 convention */
|
||||
case 1: return -(int)load->value.b[comp];
|
||||
case 8: return load->value.i8[comp];
|
||||
case 16: return load->value.i16[comp];
|
||||
case 32: return load->value.i32[comp];
|
||||
@@ -1223,6 +1226,7 @@ nir_src_comp_as_uint(nir_src src, unsigned comp)
|
||||
|
||||
assert(comp < load->def.num_components);
|
||||
switch (load->def.bit_size) {
|
||||
case 1: return load->value.b[comp];
|
||||
case 8: return load->value.u8[comp];
|
||||
case 16: return load->value.u16[comp];
|
||||
case 32: return load->value.u32[comp];
|
||||
@@ -1235,15 +1239,12 @@ nir_src_comp_as_uint(nir_src src, unsigned comp)
|
||||
bool
|
||||
nir_src_comp_as_bool(nir_src src, unsigned comp)
|
||||
{
|
||||
assert(nir_src_is_const(src));
|
||||
nir_load_const_instr *load = nir_instr_as_load_const(src.ssa->parent_instr);
|
||||
int64_t i = nir_src_comp_as_int(src, comp);
|
||||
|
||||
assert(comp < load->def.num_components);
|
||||
assert(load->def.bit_size == 32);
|
||||
assert(load->value.u32[comp] == NIR_TRUE ||
|
||||
load->value.u32[comp] == NIR_FALSE);
|
||||
/* Booleans of any size use 0/-1 convention */
|
||||
assert(i == 0 || i == -1);
|
||||
|
||||
return load->value.u32[comp];
|
||||
return i;
|
||||
}
|
||||
|
||||
double
|
||||
|
@@ -118,6 +118,7 @@ typedef enum {
|
||||
} nir_rounding_mode;
|
||||
|
||||
typedef union {
|
||||
bool b[NIR_MAX_VEC_COMPONENTS];
|
||||
float f32[NIR_MAX_VEC_COMPONENTS];
|
||||
double f64[NIR_MAX_VEC_COMPONENTS];
|
||||
int8_t i8[NIR_MAX_VEC_COMPONENTS];
|
||||
@@ -779,17 +780,25 @@ typedef struct {
|
||||
unsigned write_mask : NIR_MAX_VEC_COMPONENTS; /* ignored if dest.is_ssa is true */
|
||||
} nir_alu_dest;
|
||||
|
||||
/** NIR sized and unsized types
|
||||
*
|
||||
* The values in this enum are carefully chosen so that the sized type is
|
||||
* just the unsized type OR the number of bits.
|
||||
*/
|
||||
typedef enum {
|
||||
nir_type_invalid = 0, /* Not a valid type */
|
||||
nir_type_float,
|
||||
nir_type_int,
|
||||
nir_type_uint,
|
||||
nir_type_bool,
|
||||
nir_type_int = 2,
|
||||
nir_type_uint = 4,
|
||||
nir_type_bool = 6,
|
||||
nir_type_float = 128,
|
||||
nir_type_bool1 = 1 | nir_type_bool,
|
||||
nir_type_bool32 = 32 | nir_type_bool,
|
||||
nir_type_int1 = 1 | nir_type_int,
|
||||
nir_type_int8 = 8 | nir_type_int,
|
||||
nir_type_int16 = 16 | nir_type_int,
|
||||
nir_type_int32 = 32 | nir_type_int,
|
||||
nir_type_int64 = 64 | nir_type_int,
|
||||
nir_type_uint1 = 1 | nir_type_uint,
|
||||
nir_type_uint8 = 8 | nir_type_uint,
|
||||
nir_type_uint16 = 16 | nir_type_uint,
|
||||
nir_type_uint32 = 32 | nir_type_uint,
|
||||
@@ -799,8 +808,8 @@ typedef enum {
|
||||
nir_type_float64 = 64 | nir_type_float,
|
||||
} nir_alu_type;
|
||||
|
||||
#define NIR_ALU_TYPE_SIZE_MASK 0xfffffff8
|
||||
#define NIR_ALU_TYPE_BASE_TYPE_MASK 0x00000007
|
||||
#define NIR_ALU_TYPE_SIZE_MASK 0x79
|
||||
#define NIR_ALU_TYPE_BASE_TYPE_MASK 0x86
|
||||
|
||||
static inline unsigned
|
||||
nir_alu_type_get_type_size(nir_alu_type type)
|
||||
|
@@ -332,7 +332,10 @@ nir_imm_intN_t(nir_builder *build, uint64_t x, unsigned bit_size)
|
||||
|
||||
memset(&v, 0, sizeof(v));
|
||||
assert(bit_size <= 64);
|
||||
v.i64[0] = x & (~0ull >> (64 - bit_size));
|
||||
if (bit_size == 1)
|
||||
v.b[0] = x & 1;
|
||||
else
|
||||
v.i64[0] = x & (~0ull >> (64 - bit_size));
|
||||
|
||||
return nir_build_imm(build, 1, bit_size, v);
|
||||
}
|
||||
@@ -351,6 +354,13 @@ nir_imm_ivec4(nir_builder *build, int x, int y, int z, int w)
|
||||
return nir_build_imm(build, 4, 32, v);
|
||||
}
|
||||
|
||||
static inline nir_ssa_def *
|
||||
nir_imm_boolN_t(nir_builder *build, bool x, unsigned bit_size)
|
||||
{
|
||||
/* We use a 0/-1 convention for all booleans regardless of size */
|
||||
return nir_imm_intN_t(build, -(int)x, bit_size);
|
||||
}
|
||||
|
||||
static inline nir_ssa_def *
|
||||
nir_build_alu(nir_builder *build, nir_op op, nir_ssa_def *src0,
|
||||
nir_ssa_def *src1, nir_ssa_def *src2, nir_ssa_def *src3)
|
||||
|
@@ -24,7 +24,9 @@ def op_bit_sizes(op):
|
||||
return sorted(list(sizes)) if sizes is not None else None
|
||||
|
||||
def get_const_field(type_):
|
||||
if type_base_type(type_) == 'bool':
|
||||
if type_size(type_) == 1:
|
||||
return 'b'
|
||||
elif type_base_type(type_) == 'bool':
|
||||
return 'i' + str(type_size(type_))
|
||||
elif type_ == "float16":
|
||||
return "u16"
|
||||
@@ -237,9 +239,12 @@ unpack_half_1x16(uint16_t u)
|
||||
}
|
||||
|
||||
/* Some typed vector structures to make things like src0.y work */
|
||||
typedef int8_t int1_t;
|
||||
typedef uint8_t uint1_t;
|
||||
typedef float float16_t;
|
||||
typedef float float32_t;
|
||||
typedef double float64_t;
|
||||
typedef bool bool1_t;
|
||||
typedef bool bool8_t;
|
||||
typedef bool bool16_t;
|
||||
typedef bool bool32_t;
|
||||
@@ -274,7 +279,10 @@ struct ${type}${width}_vec {
|
||||
|
||||
const struct ${input_types[j]}_vec src${j} = {
|
||||
% for k in range(op.input_sizes[j]):
|
||||
% if input_types[j] == "float16":
|
||||
% if input_types[j] == "int1":
|
||||
/* 1-bit integers use a 0/-1 convention */
|
||||
-(int1_t)_src[${j}].b[${k}],
|
||||
% elif input_types[j] == "float16":
|
||||
_mesa_half_to_float(_src[${j}].u16[${k}]),
|
||||
% else:
|
||||
_src[${j}].${get_const_field(input_types[j])}[${k}],
|
||||
@@ -299,6 +307,9 @@ struct ${type}${width}_vec {
|
||||
% elif "src" + str(j) not in op.const_expr:
|
||||
## Avoid unused variable warnings
|
||||
<% continue %>
|
||||
% elif input_types[j] == "int1":
|
||||
/* 1-bit integers use a 0/-1 convention */
|
||||
const int1_t src${j} = -(int1_t)_src[${j}].b[_i];
|
||||
% elif input_types[j] == "float16":
|
||||
const float src${j} =
|
||||
_mesa_half_to_float(_src[${j}].u16[_i]);
|
||||
@@ -321,7 +332,10 @@ struct ${type}${width}_vec {
|
||||
|
||||
## Store the current component of the actual destination to the
|
||||
## value of dst.
|
||||
% if output_type.startswith("bool"):
|
||||
% if output_type == "int1" or output_type == "uint1":
|
||||
/* 1-bit integers get truncated */
|
||||
_dst_val.b[_i] = dst & 1;
|
||||
% elif output_type.startswith("bool"):
|
||||
## Sanitize the C value to a proper NIR 0/-1 bool
|
||||
_dst_val.${get_const_field(output_type)}[_i] = -(int)dst;
|
||||
% elif output_type == "float16":
|
||||
@@ -350,7 +364,10 @@ struct ${type}${width}_vec {
|
||||
## For each component in the destination, copy the value of dst to
|
||||
## the actual destination.
|
||||
% for k in range(op.output_size):
|
||||
% if output_type == "bool32":
|
||||
% if output_type == "int1" or output_type == "uint1":
|
||||
/* 1-bit integers get truncated */
|
||||
_dst_val.b[${k}] = dst.${"xyzw"[k]} & 1;
|
||||
% elif output_type.startswith("bool"):
|
||||
## Sanitize the C value to a proper NIR 0/-1 bool
|
||||
_dst_val.${get_const_field(output_type)}[${k}] = -(int)dst.${"xyzw"[k]};
|
||||
% elif output_type == "float16":
|
||||
|
@@ -117,8 +117,15 @@ hash_load_const(uint32_t hash, const nir_load_const_instr *instr)
|
||||
{
|
||||
hash = HASH(hash, instr->def.num_components);
|
||||
|
||||
unsigned size = instr->def.num_components * (instr->def.bit_size / 8);
|
||||
hash = _mesa_fnv32_1a_accumulate_block(hash, instr->value.f32, size);
|
||||
if (instr->def.bit_size == 1) {
|
||||
for (unsigned i = 0; i < instr->def.num_components; i++) {
|
||||
uint8_t b = instr->value.b[i];
|
||||
hash = HASH(hash, b);
|
||||
}
|
||||
} else {
|
||||
unsigned size = instr->def.num_components * (instr->def.bit_size / 8);
|
||||
hash = _mesa_fnv32_1a_accumulate_block(hash, instr->value.f32, size);
|
||||
}
|
||||
|
||||
return hash;
|
||||
}
|
||||
@@ -399,8 +406,13 @@ nir_instrs_equal(const nir_instr *instr1, const nir_instr *instr2)
|
||||
if (load1->def.bit_size != load2->def.bit_size)
|
||||
return false;
|
||||
|
||||
return memcmp(load1->value.f32, load2->value.f32,
|
||||
load1->def.num_components * (load1->def.bit_size / 8u)) == 0;
|
||||
if (load1->def.bit_size == 1) {
|
||||
unsigned size = load1->def.num_components * sizeof(bool);
|
||||
return memcmp(load1->value.b, load2->value.b, size) == 0;
|
||||
} else {
|
||||
unsigned size = load1->def.num_components * (load1->def.bit_size / 8);
|
||||
return memcmp(load1->value.f32, load2->value.f32, size) == 0;
|
||||
}
|
||||
}
|
||||
case nir_instr_type_phi: {
|
||||
nir_phi_instr *phi1 = nir_instr_as_phi(instr1);
|
||||
|
@@ -63,6 +63,9 @@ lower_load_const_instr_scalar(nir_load_const_instr *lower)
|
||||
case 8:
|
||||
load_comp->value.u8[0] = lower->value.u8[i];
|
||||
break;
|
||||
case 1:
|
||||
load_comp->value.b[0] = lower->value.b[i];
|
||||
break;
|
||||
default:
|
||||
assert(!"invalid bit size");
|
||||
}
|
||||
|
@@ -88,6 +88,9 @@ constant_fold_alu_instr(nir_alu_instr *instr, void *mem_ctx)
|
||||
case 8:
|
||||
src[i].u8[j] = load_const->value.u8[instr->src[i].swizzle[j]];
|
||||
break;
|
||||
case 1:
|
||||
src[i].b[j] = load_const->value.b[instr->src[i].swizzle[j]];
|
||||
break;
|
||||
default:
|
||||
unreachable("Invalid bit size");
|
||||
}
|
||||
|
@@ -996,6 +996,9 @@ print_load_const_instr(nir_load_const_instr *instr, print_state *state)
|
||||
case 8:
|
||||
fprintf(fp, "0x%02x", instr->value.u8[i]);
|
||||
break;
|
||||
case 1:
|
||||
fprintf(fp, "%s", instr->value.b[i] ? "true" : "false");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -476,8 +476,9 @@ construct_value(nir_builder *build,
|
||||
break;
|
||||
|
||||
case nir_type_bool:
|
||||
cval = nir_imm_bool(build, c->data.u);
|
||||
cval = nir_imm_boolN_t(build, c->data.u, bit_size);
|
||||
break;
|
||||
|
||||
default:
|
||||
unreachable("Invalid alu source type");
|
||||
}
|
||||
|
@@ -818,7 +818,7 @@ validate_if(nir_if *if_stmt, validate_state *state)
|
||||
nir_cf_node *next_node = nir_cf_node_next(&if_stmt->cf_node);
|
||||
validate_assert(state, next_node->type == nir_cf_node_block);
|
||||
|
||||
validate_src(&if_stmt->condition, state, 32, 1);
|
||||
validate_src(&if_stmt->condition, state, 0, 1);
|
||||
|
||||
validate_assert(state, !exec_list_is_empty(&if_stmt->then_list));
|
||||
validate_assert(state, !exec_list_is_empty(&if_stmt->else_list));
|
||||
|
@@ -1561,6 +1561,9 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
||||
case 8:
|
||||
val->constant->values[0].u8[i] = elems[i]->values[0].u8[0];
|
||||
break;
|
||||
case 1:
|
||||
val->constant->values[0].b[i] = elems[i]->values[0].b[0];
|
||||
break;
|
||||
default:
|
||||
vtn_fail("Invalid SpvOpConstantComposite bit size");
|
||||
}
|
||||
@@ -1734,6 +1737,9 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
||||
case 8:
|
||||
val->constant->values[0].u8[i] = (*c)->values[col].u8[elem + i];
|
||||
break;
|
||||
case 1:
|
||||
val->constant->values[0].b[i] = (*c)->values[col].b[elem + i];
|
||||
break;
|
||||
default:
|
||||
vtn_fail("Invalid SpvOpCompositeExtract bit size");
|
||||
}
|
||||
@@ -1761,6 +1767,9 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
||||
case 8:
|
||||
(*c)->values[col].u8[elem + i] = insert->constant->values[0].u8[i];
|
||||
break;
|
||||
case 1:
|
||||
(*c)->values[col].b[elem + i] = insert->constant->values[0].b[i];
|
||||
break;
|
||||
default:
|
||||
vtn_fail("Invalid SpvOpCompositeInsert bit size");
|
||||
}
|
||||
|
Reference in New Issue
Block a user