spirv/nir/opencl: handle some multiply instructions.

This adds support for the previously unhandled OpenCL 24-bit multiply
(mul24/mad24) and mad_hi instructions, in signed and unsigned variants.

Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
This commit is contained in:
Dave Airlie
2019-04-30 06:57:11 +10:00
parent 5375c30234
commit 12913bcf86
2 changed files with 55 additions and 0 deletions

View File

@@ -82,6 +82,43 @@ nir_uabs_diff(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
return nir_bcsel(b, cond, res0, res1);
}
/* 24-bit multiply helper: each operand is ANDed with 0xffffff and the two
 * masked values are then multiplied with nir_imul.
 */
static inline nir_ssa_def *
nir_umul24(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *low24 = nir_imm_int(b, 0xffffff);

   /* Mask x first, then y, before emitting the multiply. */
   nir_ssa_def *lhs = nir_iand(b, x, low24);
   nir_ssa_def *rhs = nir_iand(b, y, low24);

   return nir_imul(b, lhs, rhs);
}
/* Multiply-add built on nir_umul24: returns nir_umul24(x, y) + z, with the
 * addition emitted through nir_iadd.
 */
static inline nir_ssa_def *
nir_umad24(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *z)
{
   /* The product is built first; it becomes the left operand of the add. */
   return nir_iadd(b, nir_umul24(b, x, y), z);
}
/* Multiply-add built on nir_imul24: returns nir_imul24(x, y) + z, with the
 * addition emitted through nir_iadd.
 */
static inline nir_ssa_def *
nir_imad24(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *z)
{
   /* The product is built first; it becomes the left operand of the add. */
   return nir_iadd(b, nir_imul24(b, x, y), z);
}
/* Returns nir_imul_high(x, y) + z, with the addition emitted through
 * nir_iadd.
 */
static inline nir_ssa_def *
nir_imad_hi(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *z)
{
   /* The nir_imul_high result is built first, then z is added to it. */
   return nir_iadd(b, nir_imul_high(b, x, y), z);
}
/* Returns nir_umul_high(x, y) + z, with the addition emitted through
 * nir_iadd.
 */
static inline nir_ssa_def *
nir_umad_hi(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *z)
{
   /* The nir_umul_high result is built first, then z is added to it. */
   return nir_iadd(b, nir_umul_high(b, x, y), z);
}
static inline nir_ssa_def *
nir_bitselect(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *s)
{

View File

@@ -129,6 +129,18 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
return nir_uabs_diff(nb, srcs[0], srcs[1]);
case OpenCLstd_Bitselect:
return nir_bitselect(nb, srcs[0], srcs[1], srcs[2]);
case OpenCLstd_SMad_hi:
return nir_imad_hi(nb, srcs[0], srcs[1], srcs[2]);
case OpenCLstd_UMad_hi:
return nir_umad_hi(nb, srcs[0], srcs[1], srcs[2]);
case OpenCLstd_SMul24:
return nir_imul24(nb, srcs[0], srcs[1]);
case OpenCLstd_UMul24:
return nir_umul24(nb, srcs[0], srcs[1]);
case OpenCLstd_SMad24:
return nir_imad24(nb, srcs[0], srcs[1], srcs[2]);
case OpenCLstd_UMad24:
return nir_umad24(nb, srcs[0], srcs[1], srcs[2]);
case OpenCLstd_FClamp:
return nir_fclamp(nb, srcs[0], srcs[1], srcs[2]);
case OpenCLstd_SClamp:
@@ -288,6 +300,12 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode,
return true;
case OpenCLstd_SAbs_diff:
case OpenCLstd_UAbs_diff:
case OpenCLstd_SMad_hi:
case OpenCLstd_UMad_hi:
case OpenCLstd_SMad24:
case OpenCLstd_UMad24:
case OpenCLstd_SMul24:
case OpenCLstd_UMul24:
case OpenCLstd_Bitselect:
case OpenCLstd_FClamp:
case OpenCLstd_SClamp: