diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index c891e7564d3..0c1ac862182 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -4133,7 +4133,6 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode, break; case SpvCapabilityLinkage: - case SpvCapabilityFloat16Buffer: case SpvCapabilitySparseResidency: vtn_warn("Unsupported SPIR-V capability: %s", spirv_capability_to_string(cap)); @@ -4181,6 +4180,7 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode, break; case SpvCapabilityKernel: + case SpvCapabilityFloat16Buffer: spv_check_supported(kernel, cap); break; diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index 205eca082c0..9daffbc7a4f 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -382,7 +382,7 @@ handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member, b->nb.exact = true; } -static nir_rounding_mode +nir_rounding_mode vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode) { switch (mode) { diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c index 13adb846d38..301cb112550 100644 --- a/src/compiler/spirv/vtn_opencl.c +++ b/src/compiler/spirv/vtn_opencl.c @@ -614,7 +614,8 @@ handle_core(struct vtn_builder *b, uint32_t opcode, static void _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, - const uint32_t *w, unsigned count, bool load) + const uint32_t *w, unsigned count, bool load, + bool vec_aligned, nir_rounding_mode rounding) { struct vtn_type *type; if (load) @@ -629,12 +630,28 @@ _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, nir_ssa_def *offset = vtn_get_nir_ssa(b, w[5 + a]); struct vtn_value *p = vtn_value(b, w[6 + a], vtn_value_type_pointer); + enum glsl_base_type ptr_base_type = + glsl_get_base_type(p->pointer->type->type); + if (base_type != ptr_base_type) { + vtn_fail_if(ptr_base_type != GLSL_TYPE_FLOAT16 || + (base_type != GLSL_TYPE_FLOAT && + base_type != GLSL_TYPE_DOUBLE), + "vload/vstore cannot do type conversion. " + "vload/vstore_half can only convert from half to other " + "floating-point types."); + } + struct vtn_ssa_value *comps[NIR_MAX_VEC_COMPONENTS]; nir_ssa_def *ncomps[NIR_MAX_VEC_COMPONENTS]; - nir_ssa_def *moffset = nir_imul_imm(&b->nb, offset, components); + nir_ssa_def *moffset = nir_imul_imm(&b->nb, offset, + (vec_aligned && components == 3) ? 4 : components); nir_deref_instr *deref = vtn_pointer_to_deref(b, p->pointer); + unsigned alignment = vec_aligned ? glsl_get_cl_alignment(type->type) : + glsl_get_bit_size(type->type) / 8; + deref = nir_alignment_deref_cast(&b->nb, deref, alignment, 0); + for (int i = 0; i < components; i++) { nir_ssa_def *coffset = nir_iadd_imm(&b->nb, moffset, i); nir_deref_instr *arr_deref = nir_build_deref_ptr_as_array(&b->nb, deref, coffset); @@ -642,10 +659,30 @@ _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, if (load) { comps[i] = vtn_local_load(b, arr_deref, p->type->access); ncomps[i] = comps[i]->def; + if (base_type != ptr_base_type) { + assert(ptr_base_type == GLSL_TYPE_FLOAT16 && + (base_type == GLSL_TYPE_FLOAT || + base_type == GLSL_TYPE_DOUBLE)); + ncomps[i] = nir_f2fN(&b->nb, ncomps[i], + glsl_base_type_get_bit_size(base_type)); + } } else { struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, glsl_scalar_type(base_type)); struct vtn_ssa_value *val = vtn_ssa_value(b, w[5]); ssa->def = nir_channel(&b->nb, val->def, i); + if (base_type != ptr_base_type) { + assert(ptr_base_type == GLSL_TYPE_FLOAT16 && + (base_type == GLSL_TYPE_FLOAT || + base_type == GLSL_TYPE_DOUBLE)); + if (rounding == nir_rounding_mode_undef) { + ssa->def = nir_f2f16(&b->nb, ssa->def); + } else { + ssa->def = nir_convert_alu_types(&b->nb, ssa->def, + nir_type_float, + nir_type_float16, + rounding, false); + } + } vtn_local_store(b, ssa, arr_deref, p->type->access); } } @@ -658,14 +695,27 @@ static void vtn_handle_opencl_vload(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, const uint32_t *w, unsigned count) { - _handle_v_load_store(b, opcode, w, count, true); + _handle_v_load_store(b, opcode, w, count, true, + opcode == OpenCLstd_Vloada_halfn, + nir_rounding_mode_undef); } static void vtn_handle_opencl_vstore(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, const uint32_t *w, unsigned count) { - _handle_v_load_store(b, opcode, w, count, false); + _handle_v_load_store(b, opcode, w, count, false, + opcode == OpenCLstd_Vstorea_halfn, + nir_rounding_mode_undef); +} + +static void +vtn_handle_opencl_vstore_half_r(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, + const uint32_t *w, unsigned count) +{ + _handle_v_load_store(b, opcode, w, count, false, + opcode == OpenCLstd_Vstorea_halfn_r, + vtn_rounding_mode_to_nir(b, w[8])); } static nir_ssa_def * @@ -895,11 +945,22 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode, handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_special); return true; case OpenCLstd_Vloadn: + case OpenCLstd_Vload_half: + case OpenCLstd_Vload_halfn: + case OpenCLstd_Vloada_halfn: vtn_handle_opencl_vload(b, cl_opcode, w, count); return true; case OpenCLstd_Vstoren: + case OpenCLstd_Vstore_half: + case OpenCLstd_Vstore_halfn: + case OpenCLstd_Vstorea_halfn: vtn_handle_opencl_vstore(b, cl_opcode, w, count); return true; + case OpenCLstd_Vstore_half_r: + case OpenCLstd_Vstore_halfn_r: + case OpenCLstd_Vstorea_halfn_r: + vtn_handle_opencl_vstore_half_r(b, cl_opcode, w, count); + return true; case OpenCLstd_Shuffle: handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_shuffle); return true; diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index 54c7b11f3e3..f5f4ce8a9bb 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -900,6 +900,9 @@ enum vtn_variable_mode vtn_storage_class_to_mode(struct vtn_builder *b, nir_address_format vtn_mode_to_address_format(struct vtn_builder *b, enum vtn_variable_mode); +nir_rounding_mode vtn_rounding_mode_to_nir(struct vtn_builder *b, + SpvFPRoundingMode mode); + static inline uint32_t vtn_align_u32(uint32_t v, uint32_t a) {