spirv: Implement vload[a]_half[n] and vstore[a]_half[n][_r]

Note that the aligned variants (vloada_halfn and vstorea_halfn[_r])
aren't handled specially yet.

The Float16Buffer capability is at least partially supported after
this patch, so move it from the unsupported capabilities to the case
that is gated on kernel support.

v2 (Jason Ekstrand):
 - A few cosmetic cleanups around type/base_type
 - Rebased on top of the big SPIR-V SSA value rework
 - Use the new version of the conversion helpers

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6945>
This commit is contained in:
Jesse Natalie
2020-07-30 16:45:46 -07:00
committed by Marge Bot
parent a85afb797e
commit 7d97f3dfdc
4 changed files with 70 additions and 6 deletions

View File

@@ -4133,7 +4133,6 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
break;
case SpvCapabilityLinkage:
case SpvCapabilityFloat16Buffer:
case SpvCapabilitySparseResidency:
vtn_warn("Unsupported SPIR-V capability: %s",
spirv_capability_to_string(cap));
@@ -4181,6 +4180,7 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
break;
case SpvCapabilityKernel:
case SpvCapabilityFloat16Buffer:
spv_check_supported(kernel, cap);
break;

View File

@@ -382,7 +382,7 @@ handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
b->nb.exact = true;
}
static nir_rounding_mode
nir_rounding_mode
vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
{
switch (mode) {

View File

@@ -614,7 +614,8 @@ handle_core(struct vtn_builder *b, uint32_t opcode,
static void
_handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
const uint32_t *w, unsigned count, bool load)
const uint32_t *w, unsigned count, bool load,
bool vec_aligned, nir_rounding_mode rounding)
{
struct vtn_type *type;
if (load)
@@ -629,12 +630,28 @@ _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
nir_ssa_def *offset = vtn_get_nir_ssa(b, w[5 + a]);
struct vtn_value *p = vtn_value(b, w[6 + a], vtn_value_type_pointer);
enum glsl_base_type ptr_base_type =
glsl_get_base_type(p->pointer->type->type);
if (base_type != ptr_base_type) {
vtn_fail_if(ptr_base_type != GLSL_TYPE_FLOAT16 ||
(base_type != GLSL_TYPE_FLOAT &&
base_type != GLSL_TYPE_DOUBLE),
"vload/vstore cannot do type conversion. "
"vload/vstore_half can only convert from half to other "
"floating-point types.");
}
struct vtn_ssa_value *comps[NIR_MAX_VEC_COMPONENTS];
nir_ssa_def *ncomps[NIR_MAX_VEC_COMPONENTS];
nir_ssa_def *moffset = nir_imul_imm(&b->nb, offset, components);
nir_ssa_def *moffset = nir_imul_imm(&b->nb, offset,
(vec_aligned && components == 3) ? 4 : components);
nir_deref_instr *deref = vtn_pointer_to_deref(b, p->pointer);
unsigned alignment = vec_aligned ? glsl_get_cl_alignment(type->type) :
glsl_get_bit_size(type->type) / 8;
deref = nir_alignment_deref_cast(&b->nb, deref, alignment, 0);
for (int i = 0; i < components; i++) {
nir_ssa_def *coffset = nir_iadd_imm(&b->nb, moffset, i);
nir_deref_instr *arr_deref = nir_build_deref_ptr_as_array(&b->nb, deref, coffset);
@@ -642,10 +659,30 @@ _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
if (load) {
comps[i] = vtn_local_load(b, arr_deref, p->type->access);
ncomps[i] = comps[i]->def;
if (base_type != ptr_base_type) {
assert(ptr_base_type == GLSL_TYPE_FLOAT16 &&
(base_type == GLSL_TYPE_FLOAT ||
base_type == GLSL_TYPE_DOUBLE));
ncomps[i] = nir_f2fN(&b->nb, ncomps[i],
glsl_base_type_get_bit_size(base_type));
}
} else {
struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, glsl_scalar_type(base_type));
struct vtn_ssa_value *val = vtn_ssa_value(b, w[5]);
ssa->def = nir_channel(&b->nb, val->def, i);
if (base_type != ptr_base_type) {
assert(ptr_base_type == GLSL_TYPE_FLOAT16 &&
(base_type == GLSL_TYPE_FLOAT ||
base_type == GLSL_TYPE_DOUBLE));
if (rounding == nir_rounding_mode_undef) {
ssa->def = nir_f2f16(&b->nb, ssa->def);
} else {
ssa->def = nir_convert_alu_types(&b->nb, ssa->def,
nir_type_float,
nir_type_float16,
rounding, false);
}
}
vtn_local_store(b, ssa, arr_deref, p->type->access);
}
}
@@ -658,14 +695,27 @@ static void
vtn_handle_opencl_vload(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
const uint32_t *w, unsigned count)
{
_handle_v_load_store(b, opcode, w, count, true);
_handle_v_load_store(b, opcode, w, count, true,
opcode == OpenCLstd_Vloada_halfn,
nir_rounding_mode_undef);
}
static void
vtn_handle_opencl_vstore(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
const uint32_t *w, unsigned count)
{
_handle_v_load_store(b, opcode, w, count, false);
_handle_v_load_store(b, opcode, w, count, false,
opcode == OpenCLstd_Vstorea_halfn,
nir_rounding_mode_undef);
}
static void
vtn_handle_opencl_vstore_half_r(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
const uint32_t *w, unsigned count)
{
_handle_v_load_store(b, opcode, w, count, false,
opcode == OpenCLstd_Vstorea_halfn_r,
vtn_rounding_mode_to_nir(b, w[8]));
}
static nir_ssa_def *
@@ -895,11 +945,22 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode,
handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_special);
return true;
case OpenCLstd_Vloadn:
case OpenCLstd_Vload_half:
case OpenCLstd_Vload_halfn:
case OpenCLstd_Vloada_halfn:
vtn_handle_opencl_vload(b, cl_opcode, w, count);
return true;
case OpenCLstd_Vstoren:
case OpenCLstd_Vstore_half:
case OpenCLstd_Vstore_halfn:
case OpenCLstd_Vstorea_halfn:
vtn_handle_opencl_vstore(b, cl_opcode, w, count);
return true;
case OpenCLstd_Vstore_half_r:
case OpenCLstd_Vstore_halfn_r:
case OpenCLstd_Vstorea_halfn_r:
vtn_handle_opencl_vstore_half_r(b, cl_opcode, w, count);
return true;
case OpenCLstd_Shuffle:
handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_shuffle);
return true;

View File

@@ -900,6 +900,9 @@ enum vtn_variable_mode vtn_storage_class_to_mode(struct vtn_builder *b,
nir_address_format vtn_mode_to_address_format(struct vtn_builder *b,
enum vtn_variable_mode);
/* Maps a SPIR-V floating-point rounding mode (SpvFPRoundingMode) to a
 * nir_rounding_mode value.
 */
nir_rounding_mode vtn_rounding_mode_to_nir(struct vtn_builder *b,
SpvFPRoundingMode mode);
static inline uint32_t
vtn_align_u32(uint32_t v, uint32_t a)
{