diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c index d82bdf4524e..28c65499d85 100644 --- a/src/compiler/spirv/vtn_opencl.c +++ b/src/compiler/spirv/vtn_opencl.c @@ -31,7 +31,7 @@ #include "OpenCL.std.h" typedef nir_ssa_def *(*nir_handler)(struct vtn_builder *b, - enum OpenCLstd_Entrypoints opcode, + uint32_t opcode, unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types, const struct vtn_type *dest_type); @@ -205,27 +205,26 @@ static bool call_mangled_function(struct vtn_builder *b, } static void -handle_instr(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, - const uint32_t *w, unsigned count, nir_handler handler) +handle_instr(struct vtn_builder *b, uint32_t opcode, + const uint32_t *w_src, unsigned num_srcs, const uint32_t *w_dest, nir_handler handler) { - struct vtn_type *dest_type = vtn_get_type(b, w[1]); + struct vtn_type *dest_type = w_dest ? vtn_get_type(b, w_dest[0]) : NULL; - unsigned num_srcs = count - 5; - nir_ssa_def *srcs[3] = { NULL }; - struct vtn_type *src_types[3] = { NULL }; + nir_ssa_def *srcs[5] = { NULL }; + struct vtn_type *src_types[5] = { NULL }; vtn_assert(num_srcs <= ARRAY_SIZE(srcs)); for (unsigned i = 0; i < num_srcs; i++) { - struct vtn_value *val = vtn_untyped_value(b, w[i + 5]); - struct vtn_ssa_value *ssa = vtn_ssa_value(b, w[i + 5]); + struct vtn_value *val = vtn_untyped_value(b, w_src[i]); + struct vtn_ssa_value *ssa = vtn_ssa_value(b, w_src[i]); srcs[i] = ssa->def; src_types[i] = val->type; } nir_ssa_def *result = handler(b, opcode, num_srcs, srcs, src_types, dest_type); if (result) { - vtn_push_nir_ssa(b, w[2], result); + vtn_push_nir_ssa(b, w_dest[1], result); } else { - vtn_assert(dest_type->type == glsl_void_type()); + vtn_assert(dest_type == NULL); } } @@ -286,11 +285,11 @@ nir_alu_op_for_opencl_opcode(struct vtn_builder *b, } static nir_ssa_def * -handle_alu(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, +handle_alu(struct vtn_builder *b, uint32_t opcode, unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types, const struct vtn_type *dest_type) { - nir_ssa_def *ret = nir_build_alu(&b->nb, nir_alu_op_for_opencl_opcode(b, opcode), + nir_ssa_def *ret = nir_build_alu(&b->nb, nir_alu_op_for_opencl_opcode(b, (enum OpenCLstd_Entrypoints)opcode), srcs[0], srcs[1], srcs[2], NULL); if (opcode == OpenCLstd_Popcount) ret = nir_u2u(&b->nb, ret, glsl_get_bit_size(dest_type->type)); @@ -481,13 +480,14 @@ handle_clc_fn(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, } static nir_ssa_def * -handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, +handle_special(struct vtn_builder *b, uint32_t opcode, unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types, const struct vtn_type *dest_type) { nir_builder *nb = &b->nb; + enum OpenCLstd_Entrypoints cl_opcode = (enum OpenCLstd_Entrypoints)opcode; - switch (opcode) { + switch (cl_opcode) { case OpenCLstd_SAbs_diff: /* these works easier in direct NIR */ return nir_iabs_diff(nb, srcs[0], srcs[1]); @@ -639,7 +639,7 @@ vtn_handle_opencl_vstore(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcod } static nir_ssa_def * -handle_printf(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, +handle_printf(struct vtn_builder *b, uint32_t opcode, unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types, const struct vtn_type *dest_type) { @@ -648,8 +648,8 @@ handle_printf(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, } static nir_ssa_def * -handle_round(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned num_srcs, - nir_ssa_def **srcs, struct vtn_type **src_types, +handle_round(struct vtn_builder *b, uint32_t opcode, + unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types, const struct vtn_type *dest_type) { nir_ssa_def *src = srcs[0]; @@ -663,8 +663,8 @@ handle_round(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned } static nir_ssa_def * -handle_shuffle(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned num_srcs, - nir_ssa_def **srcs, struct vtn_type **src_types, +handle_shuffle(struct vtn_builder *b, uint32_t opcode, + unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types, const struct vtn_type *dest_type) { struct nir_ssa_def *input = srcs[0]; @@ -683,8 +683,8 @@ handle_shuffle(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigne } static nir_ssa_def * -handle_shuffle2(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned num_srcs, - nir_ssa_def **srcs, struct vtn_type **src_types, +handle_shuffle2(struct vtn_builder *b, uint32_t opcode, + unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types, const struct vtn_type *dest_type) { struct nir_ssa_def *input0 = srcs[0]; @@ -762,7 +762,7 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode, case OpenCLstd_USub_sat: case OpenCLstd_Trunc: case OpenCLstd_Rint: - handle_instr(b, cl_opcode, w, count, handle_alu); + handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_alu); return true; case OpenCLstd_SAbs_diff: case OpenCLstd_UAbs_diff: @@ -860,7 +860,7 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode, case OpenCLstd_Half_powr: case OpenCLstd_Half_sin: case OpenCLstd_Half_tan: - handle_instr(b, cl_opcode, w, count, handle_special); + handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_special); return true; case OpenCLstd_Vloadn: vtn_handle_opencl_vload(b, cl_opcode, w, count); @@ -869,16 +869,16 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode, vtn_handle_opencl_vstore(b, cl_opcode, w, count); return true; case OpenCLstd_Shuffle: - handle_instr(b, cl_opcode, w, count, handle_shuffle); + handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_shuffle); return true; case OpenCLstd_Shuffle2: - handle_instr(b, cl_opcode, w, count, handle_shuffle2); + handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_shuffle2); return true; case OpenCLstd_Round: - handle_instr(b, cl_opcode, w, count, handle_round); + handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_round); return true; case OpenCLstd_Printf: - handle_instr(b, cl_opcode, w, count, handle_printf); + handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_printf); return true; case OpenCLstd_Prefetch: /* TODO maybe add a nir instruction for this? */