vtn/opencl: Rework handle_instr to be able to handle core SPIR-V opcodes via libclc

The OpenCL async copy/wait opcodes are core SPIR-V, rather than OpenCL extension opcodes.

Reviewed-by: Dave Airlie <airlied@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6035>
This commit is contained in:
Jesse Natalie
2020-09-25 12:11:14 -07:00
committed by Marge Bot
parent b08fd45be0
commit 00261d883d

View File

@@ -31,7 +31,7 @@
#include "OpenCL.std.h"
typedef nir_ssa_def *(*nir_handler)(struct vtn_builder *b,
enum OpenCLstd_Entrypoints opcode,
uint32_t opcode,
unsigned num_srcs, nir_ssa_def **srcs,
struct vtn_type **src_types,
const struct vtn_type *dest_type);
@@ -205,27 +205,26 @@ static bool call_mangled_function(struct vtn_builder *b,
}
static void
handle_instr(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
const uint32_t *w, unsigned count, nir_handler handler)
handle_instr(struct vtn_builder *b, uint32_t opcode,
const uint32_t *w_src, unsigned num_srcs, const uint32_t *w_dest, nir_handler handler)
{
struct vtn_type *dest_type = vtn_get_type(b, w[1]);
struct vtn_type *dest_type = w_dest ? vtn_get_type(b, w_dest[0]) : NULL;
unsigned num_srcs = count - 5;
nir_ssa_def *srcs[3] = { NULL };
struct vtn_type *src_types[3] = { NULL };
nir_ssa_def *srcs[5] = { NULL };
struct vtn_type *src_types[5] = { NULL };
vtn_assert(num_srcs <= ARRAY_SIZE(srcs));
for (unsigned i = 0; i < num_srcs; i++) {
struct vtn_value *val = vtn_untyped_value(b, w[i + 5]);
struct vtn_ssa_value *ssa = vtn_ssa_value(b, w[i + 5]);
struct vtn_value *val = vtn_untyped_value(b, w_src[i]);
struct vtn_ssa_value *ssa = vtn_ssa_value(b, w_src[i]);
srcs[i] = ssa->def;
src_types[i] = val->type;
}
nir_ssa_def *result = handler(b, opcode, num_srcs, srcs, src_types, dest_type);
if (result) {
vtn_push_nir_ssa(b, w[2], result);
vtn_push_nir_ssa(b, w_dest[1], result);
} else {
vtn_assert(dest_type->type == glsl_void_type());
vtn_assert(dest_type == NULL);
}
}
@@ -286,11 +285,11 @@ nir_alu_op_for_opencl_opcode(struct vtn_builder *b,
}
static nir_ssa_def *
handle_alu(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
handle_alu(struct vtn_builder *b, uint32_t opcode,
unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
const struct vtn_type *dest_type)
{
nir_ssa_def *ret = nir_build_alu(&b->nb, nir_alu_op_for_opencl_opcode(b, opcode),
nir_ssa_def *ret = nir_build_alu(&b->nb, nir_alu_op_for_opencl_opcode(b, (enum OpenCLstd_Entrypoints)opcode),
srcs[0], srcs[1], srcs[2], NULL);
if (opcode == OpenCLstd_Popcount)
ret = nir_u2u(&b->nb, ret, glsl_get_bit_size(dest_type->type));
@@ -481,13 +480,14 @@ handle_clc_fn(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
}
static nir_ssa_def *
handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
handle_special(struct vtn_builder *b, uint32_t opcode,
unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
const struct vtn_type *dest_type)
{
nir_builder *nb = &b->nb;
enum OpenCLstd_Entrypoints cl_opcode = (enum OpenCLstd_Entrypoints)opcode;
switch (opcode) {
switch (cl_opcode) {
case OpenCLstd_SAbs_diff:
/* these works easier in direct NIR */
return nir_iabs_diff(nb, srcs[0], srcs[1]);
@@ -639,7 +639,7 @@ vtn_handle_opencl_vstore(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcod
}
static nir_ssa_def *
handle_printf(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
handle_printf(struct vtn_builder *b, uint32_t opcode,
unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
const struct vtn_type *dest_type)
{
@@ -648,8 +648,8 @@ handle_printf(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
}
static nir_ssa_def *
handle_round(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned num_srcs,
nir_ssa_def **srcs, struct vtn_type **src_types,
handle_round(struct vtn_builder *b, uint32_t opcode,
unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
const struct vtn_type *dest_type)
{
nir_ssa_def *src = srcs[0];
@@ -663,8 +663,8 @@ handle_round(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned
}
static nir_ssa_def *
handle_shuffle(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned num_srcs,
nir_ssa_def **srcs, struct vtn_type **src_types,
handle_shuffle(struct vtn_builder *b, uint32_t opcode,
unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
const struct vtn_type *dest_type)
{
struct nir_ssa_def *input = srcs[0];
@@ -683,8 +683,8 @@ handle_shuffle(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigne
}
static nir_ssa_def *
handle_shuffle2(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned num_srcs,
nir_ssa_def **srcs, struct vtn_type **src_types,
handle_shuffle2(struct vtn_builder *b, uint32_t opcode,
unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
const struct vtn_type *dest_type)
{
struct nir_ssa_def *input0 = srcs[0];
@@ -762,7 +762,7 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode,
case OpenCLstd_USub_sat:
case OpenCLstd_Trunc:
case OpenCLstd_Rint:
handle_instr(b, cl_opcode, w, count, handle_alu);
handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_alu);
return true;
case OpenCLstd_SAbs_diff:
case OpenCLstd_UAbs_diff:
@@ -860,7 +860,7 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode,
case OpenCLstd_Half_powr:
case OpenCLstd_Half_sin:
case OpenCLstd_Half_tan:
handle_instr(b, cl_opcode, w, count, handle_special);
handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_special);
return true;
case OpenCLstd_Vloadn:
vtn_handle_opencl_vload(b, cl_opcode, w, count);
@@ -869,16 +869,16 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode,
vtn_handle_opencl_vstore(b, cl_opcode, w, count);
return true;
case OpenCLstd_Shuffle:
handle_instr(b, cl_opcode, w, count, handle_shuffle);
handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_shuffle);
return true;
case OpenCLstd_Shuffle2:
handle_instr(b, cl_opcode, w, count, handle_shuffle2);
handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_shuffle2);
return true;
case OpenCLstd_Round:
handle_instr(b, cl_opcode, w, count, handle_round);
handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_round);
return true;
case OpenCLstd_Printf:
handle_instr(b, cl_opcode, w, count, handle_printf);
handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_printf);
return true;
case OpenCLstd_Prefetch:
/* TODO maybe add a nir instruction for this? */