nir+vtn: vec8+vec16 support
This introduces new vec8 and vec16 instructions (which are the only instructions taking more than 4 sources), in order to construct 8 and 16 component vectors. In order to avoid fixing up the non-autogenerated nir_build_alu() sites and making them pass 16 src args for the benefit of the two instructions that take more than 4 srcs (ie vec8 and vec16), nir_build_alu() is has nir_build_alu_tail() split out and re-used by nir_build_alu2() (which is used for the > 4 src args case). v2 (Karol Herbst): use nir_build_alu2 for vec8 and vec16 use python's array multiplication syntax add nir_op_vec helper simplify nir_vec nir_build_alu_tail -> nir_builder_alu_instr_finish_and_insert use nir_build_alu for opcodes with <= 4 sources v3 (Karol Herbst): fix nir_serialize v4 (Dave Airlie): fix serialization of glsl_type handle vec8/16 in lowering of bools v5 (Karol Herbst): fix load store vectorizer Signed-off-by: Karol Herbst <kherbst@redhat.com> Reviewed-by: Dave Airlie <airlied@redhat.com>
This commit is contained in:
@@ -2630,9 +2630,13 @@ encode_type_to_blob(struct blob *blob, const glsl_type *type)
|
||||
case GLSL_TYPE_INT64:
|
||||
case GLSL_TYPE_BOOL:
|
||||
encoded.basic.interface_row_major = type->interface_row_major;
|
||||
assert(type->vector_elements < 8);
|
||||
assert(type->matrix_columns < 8);
|
||||
encoded.basic.vector_elements = type->vector_elements;
|
||||
if (type->vector_elements <= 4)
|
||||
encoded.basic.vector_elements = type->vector_elements;
|
||||
else if (type->vector_elements == 8)
|
||||
encoded.basic.vector_elements = 5;
|
||||
else if (type->vector_elements == 16)
|
||||
encoded.basic.vector_elements = 6;
|
||||
encoded.basic.matrix_columns = type->matrix_columns;
|
||||
encoded.basic.explicit_stride = MIN2(type->explicit_stride, 0xfffff);
|
||||
blob_write_uint32(blob, encoded.u32);
|
||||
@@ -2741,6 +2745,11 @@ decode_type_from_blob(struct blob_reader *blob)
|
||||
unsigned explicit_stride = encoded.basic.explicit_stride;
|
||||
if (explicit_stride == 0xfffff)
|
||||
explicit_stride = blob_read_uint32(blob);
|
||||
uint32_t vector_elements = encoded.basic.vector_elements;
|
||||
if (vector_elements == 5)
|
||||
vector_elements = 8;
|
||||
else if (vector_elements == 6)
|
||||
vector_elements = 16;
|
||||
return glsl_type::get_instance(base_type, encoded.basic.vector_elements,
|
||||
encoded.basic.matrix_columns,
|
||||
explicit_stride,
|
||||
|
@@ -58,10 +58,19 @@ extern "C" {
|
||||
|
||||
#define NIR_FALSE 0u
|
||||
#define NIR_TRUE (~0u)
|
||||
#define NIR_MAX_VEC_COMPONENTS 4
|
||||
#define NIR_MAX_VEC_COMPONENTS 16
|
||||
#define NIR_MAX_MATRIX_COLUMNS 4
|
||||
#define NIR_STREAM_PACKED (1 << 8)
|
||||
typedef uint8_t nir_component_mask_t;
|
||||
typedef uint16_t nir_component_mask_t;
|
||||
|
||||
static inline bool
|
||||
nir_num_components_valid(unsigned num_components)
|
||||
{
|
||||
return (num_components >= 1 &&
|
||||
num_components <= 4) ||
|
||||
num_components == 8 ||
|
||||
num_components == 16;
|
||||
}
|
||||
|
||||
/** Defines a cast function
|
||||
*
|
||||
@@ -1030,6 +1039,8 @@ nir_op_vec(unsigned components)
|
||||
case 2: return nir_op_vec2;
|
||||
case 3: return nir_op_vec3;
|
||||
case 4: return nir_op_vec4;
|
||||
case 8: return nir_op_vec8;
|
||||
case 16: return nir_op_vec16;
|
||||
default: unreachable("bad component count");
|
||||
}
|
||||
}
|
||||
|
@@ -874,7 +874,7 @@ nir_ssa_for_src(nir_builder *build, nir_src src, int num_components)
|
||||
static inline nir_ssa_def *
|
||||
nir_ssa_for_alu_src(nir_builder *build, nir_alu_instr *instr, unsigned srcn)
|
||||
{
|
||||
static uint8_t trivial_swizzle[] = { 0, 1, 2, 3 };
|
||||
static uint8_t trivial_swizzle[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
|
||||
STATIC_ASSERT(ARRAY_SIZE(trivial_swizzle) == NIR_MAX_VEC_COMPONENTS);
|
||||
|
||||
nir_alu_src *src = &instr->src[srcn];
|
||||
|
@@ -31,14 +31,22 @@ def src_decl_list(num_srcs):
|
||||
return ', '.join('nir_ssa_def *src' + str(i) for i in range(num_srcs))
|
||||
|
||||
def src_list(num_srcs):
|
||||
return ', '.join('src' + str(i) if i < num_srcs else 'NULL' for i in range(4))
|
||||
if num_srcs <= 4:
|
||||
return ', '.join('src' + str(i) if i < num_srcs else 'NULL' for i in range(4))
|
||||
else:
|
||||
return ', '.join('src' + str(i) for i in range(num_srcs))
|
||||
%>
|
||||
|
||||
% for name, opcode in sorted(opcodes.items()):
|
||||
static inline nir_ssa_def *
|
||||
nir_${name}(nir_builder *build, ${src_decl_list(opcode.num_inputs)})
|
||||
{
|
||||
% if opcode.num_inputs <= 4:
|
||||
return nir_build_alu(build, nir_op_${name}, ${src_list(opcode.num_inputs)});
|
||||
% else:
|
||||
nir_ssa_def *srcs[${opcode.num_inputs}] = {${src_list(opcode.num_inputs)}};
|
||||
return nir_build_alu_src_arr(build, nir_op_${name}, srcs);
|
||||
% endif
|
||||
}
|
||||
% endfor
|
||||
|
||||
|
@@ -292,6 +292,18 @@ struct ${type}${width}_vec {
|
||||
${type}${width}_t y;
|
||||
${type}${width}_t z;
|
||||
${type}${width}_t w;
|
||||
${type}${width}_t e;
|
||||
${type}${width}_t f;
|
||||
${type}${width}_t g;
|
||||
${type}${width}_t h;
|
||||
${type}${width}_t i;
|
||||
${type}${width}_t j;
|
||||
${type}${width}_t k;
|
||||
${type}${width}_t l;
|
||||
${type}${width}_t m;
|
||||
${type}${width}_t n;
|
||||
${type}${width}_t o;
|
||||
${type}${width}_t p;
|
||||
};
|
||||
% endfor
|
||||
% endfor
|
||||
@@ -324,7 +336,7 @@ struct ${type}${width}_vec {
|
||||
_src[${j}][${k}].${get_const_field(input_types[j])},
|
||||
% endif
|
||||
% endfor
|
||||
% for k in range(op.input_sizes[j], 4):
|
||||
% for k in range(op.input_sizes[j], 16):
|
||||
0,
|
||||
% endfor
|
||||
};
|
||||
@@ -418,18 +430,18 @@ struct ${type}${width}_vec {
|
||||
% for k in range(op.output_size):
|
||||
% if output_type == "int1" or output_type == "uint1":
|
||||
/* 1-bit integers get truncated */
|
||||
_dst_val[${k}].b = dst.${"xyzw"[k]} & 1;
|
||||
_dst_val[${k}].b = dst.${"xyzwefghijklmnop"[k]} & 1;
|
||||
% elif output_type.startswith("bool"):
|
||||
## Sanitize the C value to a proper NIR 0/-1 bool
|
||||
_dst_val[${k}].${get_const_field(output_type)} = -(int)dst.${"xyzw"[k]};
|
||||
_dst_val[${k}].${get_const_field(output_type)} = -(int)dst.${"xyzwefghijklmnop"[k]};
|
||||
% elif output_type == "float16":
|
||||
if (nir_is_rounding_mode_rtz(execution_mode, 16)) {
|
||||
_dst_val[${k}].u16 = _mesa_float_to_float16_rtz(dst.${"xyzw"[k]});
|
||||
_dst_val[${k}].u16 = _mesa_float_to_float16_rtz(dst.${"xyzwefghijklmnop"[k]});
|
||||
} else {
|
||||
_dst_val[${k}].u16 = _mesa_float_to_float16_rtne(dst.${"xyzw"[k]});
|
||||
_dst_val[${k}].u16 = _mesa_float_to_float16_rtne(dst.${"xyzwefghijklmnop"[k]});
|
||||
}
|
||||
% else:
|
||||
_dst_val[${k}].${get_const_field(output_type)} = dst.${"xyzw"[k]};
|
||||
_dst_val[${k}].${get_const_field(output_type)} = dst.${"xyzwefghijklmnop"[k]};
|
||||
% endif
|
||||
|
||||
% if op.name != "fquantize2f16" and type_base_type(output_type) == "float":
|
||||
|
@@ -117,6 +117,8 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
|
||||
return lower_reduction(alu, chan, merge, b); \
|
||||
|
||||
switch (alu->op) {
|
||||
case nir_op_vec16:
|
||||
case nir_op_vec8:
|
||||
case nir_op_vec4:
|
||||
case nir_op_vec3:
|
||||
case nir_op_vec2:
|
||||
|
@@ -56,6 +56,8 @@ lower_alu_instr(nir_builder *b, nir_alu_instr *alu)
|
||||
case nir_op_vec2:
|
||||
case nir_op_vec3:
|
||||
case nir_op_vec4:
|
||||
case nir_op_vec8:
|
||||
case nir_op_vec16:
|
||||
/* These we expect to have booleans but the opcode doesn't change */
|
||||
break;
|
||||
|
||||
|
@@ -53,6 +53,8 @@ lower_alu_instr(nir_alu_instr *alu)
|
||||
case nir_op_vec2:
|
||||
case nir_op_vec3:
|
||||
case nir_op_vec4:
|
||||
case nir_op_vec8:
|
||||
case nir_op_vec16:
|
||||
case nir_op_inot:
|
||||
case nir_op_iand:
|
||||
case nir_op_ior:
|
||||
|
@@ -75,7 +75,7 @@ class Opcode(object):
|
||||
assert isinstance(algebraic_properties, str)
|
||||
assert isinstance(const_expr, str)
|
||||
assert len(input_sizes) == len(input_types)
|
||||
assert 0 <= output_size <= 4
|
||||
assert 0 <= output_size <= 4 or (output_size == 8) or (output_size == 16)
|
||||
for size in input_sizes:
|
||||
assert 0 <= size <= 4
|
||||
if output_size != 0:
|
||||
@@ -1057,6 +1057,40 @@ dst.z = src2.x;
|
||||
dst.w = src3.x;
|
||||
""")
|
||||
|
||||
opcode("vec8", 8, tuint,
|
||||
[1] * 8, [tuint] * 8,
|
||||
False, "", """
|
||||
dst.x = src0.x;
|
||||
dst.y = src1.x;
|
||||
dst.z = src2.x;
|
||||
dst.w = src3.x;
|
||||
dst.e = src4.x;
|
||||
dst.f = src5.x;
|
||||
dst.g = src6.x;
|
||||
dst.h = src7.x;
|
||||
""")
|
||||
|
||||
opcode("vec16", 16, tuint,
|
||||
[1] * 16, [tuint] * 16,
|
||||
False, "", """
|
||||
dst.x = src0.x;
|
||||
dst.y = src1.x;
|
||||
dst.z = src2.x;
|
||||
dst.w = src3.x;
|
||||
dst.e = src4.x;
|
||||
dst.f = src5.x;
|
||||
dst.g = src6.x;
|
||||
dst.h = src7.x;
|
||||
dst.i = src8.x;
|
||||
dst.j = src9.x;
|
||||
dst.k = src10.x;
|
||||
dst.l = src11.x;
|
||||
dst.m = src12.x;
|
||||
dst.n = src13.x;
|
||||
dst.o = src14.x;
|
||||
dst.p = src15.x;
|
||||
""")
|
||||
|
||||
# An integer multiply instruction for address calculation. This is
|
||||
# similar to imul, except that the results are undefined in case of
|
||||
# overflow. Overflow is defined according to the size of the variable
|
||||
|
@@ -643,7 +643,7 @@ new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
|
||||
return false;
|
||||
|
||||
unsigned new_num_components = size / new_bit_size;
|
||||
if (new_num_components > NIR_MAX_VEC_COMPONENTS)
|
||||
if (!nir_num_components_valid(new_num_components))
|
||||
return false;
|
||||
|
||||
unsigned high_offset = high->offset_signed - low->offset_signed;
|
||||
|
@@ -171,6 +171,12 @@ print_dest(nir_dest *dest, print_state *state)
|
||||
print_reg_dest(&dest->reg, state);
|
||||
}
|
||||
|
||||
static const char *
|
||||
comp_mask_string(unsigned num_components)
|
||||
{
|
||||
return (num_components > 4) ? "abcdefghijklmnop" : "xyzw";
|
||||
}
|
||||
|
||||
static void
|
||||
print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state)
|
||||
{
|
||||
@@ -206,7 +212,7 @@ print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state)
|
||||
if (!nir_alu_instr_channel_used(instr, src, i))
|
||||
continue;
|
||||
|
||||
fprintf(fp, "%c", "xyzw"[instr->src[src].swizzle[i]]);
|
||||
fprintf(fp, "%c", comp_mask_string(live_channels)[instr->src[src].swizzle[i]]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,10 +230,11 @@ print_alu_dest(nir_alu_dest *dest, print_state *state)
|
||||
|
||||
if (!dest->dest.is_ssa &&
|
||||
dest->write_mask != (1 << dest->dest.reg.reg->num_components) - 1) {
|
||||
unsigned live_channels = dest->dest.reg.reg->num_components;
|
||||
fprintf(fp, ".");
|
||||
for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
|
||||
if ((dest->write_mask >> i) & 1)
|
||||
fprintf(fp, "%c", "xyzw"[i]);
|
||||
fprintf(fp, "%c", comp_mask_string(live_channels)[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -569,8 +576,8 @@ print_var_decl(nir_variable *var, print_state *state)
|
||||
switch (var->data.mode) {
|
||||
case nir_var_shader_in:
|
||||
case nir_var_shader_out:
|
||||
if (num_components < 4 && num_components != 0) {
|
||||
const char *xyzw = "xyzw";
|
||||
if (num_components < 16 && num_components != 0) {
|
||||
const char *xyzw = comp_mask_string(num_components);
|
||||
for (int i = 0; i < num_components; i++)
|
||||
components_local[i + 1] = xyzw[i + var->data.location_frac];
|
||||
|
||||
@@ -816,9 +823,9 @@ print_intrinsic_instr(nir_intrinsic_instr *instr, print_state *state)
|
||||
/* special case wrmask to show it as a writemask.. */
|
||||
unsigned wrmask = nir_intrinsic_write_mask(instr);
|
||||
fprintf(fp, " wrmask=");
|
||||
for (unsigned i = 0; i < 4; i++)
|
||||
for (unsigned i = 0; i < instr->num_components; i++)
|
||||
if ((wrmask >> i) & 1)
|
||||
fprintf(fp, "%c", "xyzw"[i]);
|
||||
fprintf(fp, "%c", comp_mask_string(instr->num_components)[i]);
|
||||
break;
|
||||
}
|
||||
|
||||
|
@@ -56,7 +56,13 @@ static bool
|
||||
nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
|
||||
const struct per_op_table *pass_op_table);
|
||||
|
||||
static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
|
||||
static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] =
|
||||
{
|
||||
0, 1, 2, 3,
|
||||
4, 5, 6, 7,
|
||||
8, 9, 10, 11,
|
||||
12, 13, 14, 15,
|
||||
};
|
||||
|
||||
/**
|
||||
* Check if a source produces a value of the given type.
|
||||
|
@@ -128,8 +128,7 @@ static void validate_src(nir_src *src, validate_state *state,
|
||||
static void
|
||||
validate_num_components(validate_state *state, unsigned num_components)
|
||||
{
|
||||
validate_assert(state, num_components >= 1 &&
|
||||
num_components <= 4);
|
||||
validate_assert(state, nir_num_components_valid(num_components));
|
||||
}
|
||||
|
||||
static void
|
||||
|
@@ -3819,10 +3819,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
|
||||
case SpvCapabilityInputAttachment:
|
||||
case SpvCapabilityImageGatherExtended:
|
||||
case SpvCapabilityStorageImageExtendedFormats:
|
||||
case SpvCapabilityVector16:
|
||||
break;
|
||||
|
||||
case SpvCapabilityLinkage:
|
||||
case SpvCapabilityVector16:
|
||||
case SpvCapabilityFloat16Buffer:
|
||||
case SpvCapabilitySparseResidency:
|
||||
vtn_warn("Unsupported SPIR-V capability: %s",
|
||||
|
Reference in New Issue
Block a user