vtn/opencl: Add infrastructure for calling out to libclc

This patch adds a function remap table with name mangling, which
can convert a SPIR-V OpenCL extension opcode to a call to the external
libclc shader, which will be lowered/inlined after conversion.

Reviewed-by: Dave Airlie <airlied@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6035>
This commit is contained in:
Jesse Natalie
2020-08-18 07:16:32 -07:00
committed by Marge Bot
parent 45d43ad2b8
commit 6436e3ac18
2 changed files with 362 additions and 2 deletions

View File

@@ -83,6 +83,8 @@ struct spirv_to_nir_options {
nir_address_format temp_addr_format;
nir_address_format constant_addr_format;
nir_shader *clc_shader;
struct {
void (*func)(void *private_data,
enum nir_spirv_debug_level level,

View File

@@ -36,6 +36,174 @@ typedef nir_ssa_def *(*nir_handler)(struct vtn_builder *b,
struct vtn_type **src_types,
const struct vtn_type *dest_type);
/* Translates a SPIR-V storage class into the address-space number that is
 * embedded in mangled pointer parameters.  Storage classes that are not
 * listed here yield -1.
 */
static int to_llvm_address_space(SpvStorageClass mode)
{
   int address_space = -1;

   switch (mode) {
   case SpvStorageClassFunction:
   case SpvStorageClassPrivate:
      address_space = 0;
      break;
   case SpvStorageClassCrossWorkgroup:
      address_space = 1;
      break;
   case SpvStorageClassUniformConstant:
   case SpvStorageClassUniform:
      address_space = 2;
      break;
   case SpvStorageClassWorkgroup:
      address_space = 3;
      break;
   default:
      break;
   }

   return address_space;
}
/* Appends the NUL-terminated string str at *cursor and advances *cursor past
 * the copied bytes.  limit points at the last byte of the destination buffer,
 * which is reserved for the terminator, so nothing is written beyond it.
 * A string that does not fit trips the assert in debug builds and is
 * truncated otherwise.
 */
static void
vtn_opencl_mangle_append(char **cursor, const char *limit, const char *str)
{
   size_t avail = (size_t)(limit - *cursor);
   size_t len = strlen(str);

   assert(len <= avail);
   if (len > avail)
      len = avail;

   memcpy(*cursor, str, len);
   *cursor += len;
   **cursor = '\0';
}

/* Builds the Itanium-C++-ABI-style mangled name for the function in_name
 * taking the parameter types in src_types.
 *
 * in_name:    unmangled function name.
 * const_mask: bit i set means parameter i gets a 'K' (const) qualifier; only
 *             the first 32 parameters can be flagged.
 * ntypes:     number of entries in src_types.
 * src_types:  parameter types; pointers are mangled as 'P', an optional
 *             "U3AS<n>" address-space qualifier, then the pointee type.
 * outstring:  receives a strdup()'d copy of the mangled name which the caller
 *             must free(); it is NULL if that allocation fails.
 *
 * The name is assembled in a fixed 256-byte buffer and every write is bounded
 * by it.
 */
static void
vtn_opencl_mangle(const char *in_name,
                  uint32_t const_mask,
                  int ntypes, struct vtn_type **src_types,
                  char **outstring)
{
   char local_name[256] = "";
   char scratch[32];
   char *args_str = local_name;
   const char *limit = local_name + sizeof(local_name) - 1;

   snprintf(scratch, sizeof(scratch), "_Z%zu", strlen(in_name));
   vtn_opencl_mangle_append(&args_str, limit, scratch);
   vtn_opencl_mangle_append(&args_str, limit, in_name);

   for (int i = 0; i < ntypes; ++i) {
      const struct glsl_type *type = src_types[i]->type;
      enum vtn_base_type base_type = src_types[i]->base_type;

      if (base_type == vtn_base_type_pointer) {
         vtn_opencl_mangle_append(&args_str, limit, "P");

         /* Address space 0 and unknown storage classes (-1) carry no
          * qualifier.
          */
         int address_space = to_llvm_address_space(src_types[i]->storage_class);
         if (address_space > 0) {
            snprintf(scratch, sizeof(scratch), "U3AS%d", address_space);
            vtn_opencl_mangle_append(&args_str, limit, scratch);
         }

         type = src_types[i]->deref->type;
         base_type = src_types[i]->deref->base_type;
      }

      /* Unsigned shift, and only for bit positions that exist in the mask. */
      if (i < 32 && (const_mask & ((uint32_t)1 << i)))
         vtn_opencl_mangle_append(&args_str, limit, "K");

      unsigned num_elements = glsl_get_components(type);
      if (num_elements > 1) {
         /* Vectors are not treated as built-ins for mangling, so check for substitution.
          * In theory, we'd need to know which substitution value this is. In practice,
          * the functions we need from libclc only support 1
          */
         bool substitution = false;
         for (int j = 0; j < i; ++j) {
            const struct glsl_type *other_type =
               src_types[j]->base_type == vtn_base_type_pointer ?
               src_types[j]->deref->type : src_types[j]->type;
            if (type == other_type) {
               substitution = true;
               break;
            }
         }

         if (substitution) {
            /* A repeated vector type is fully described by the back-reference. */
            vtn_opencl_mangle_append(&args_str, limit, "S_");
            continue;
         }

         snprintf(scratch, sizeof(scratch), "Dv%u_", num_elements);
         vtn_opencl_mangle_append(&args_str, limit, scratch);
      }

      const char *suffix = NULL;
      switch (base_type) {
      case vtn_base_type_sampler: suffix = "11ocl_sampler"; break;
      case vtn_base_type_event: suffix = "9ocl_event"; break;
      default: {
         const char *primitives[] = {
            [GLSL_TYPE_UINT] = "j",
            [GLSL_TYPE_INT] = "i",
            [GLSL_TYPE_FLOAT] = "f",
            [GLSL_TYPE_FLOAT16] = "Dh",
            [GLSL_TYPE_DOUBLE] = "d",
            [GLSL_TYPE_UINT8] = "h",
            [GLSL_TYPE_INT8] = "c",
            [GLSL_TYPE_UINT16] = "t",
            [GLSL_TYPE_INT16] = "s",
            [GLSL_TYPE_UINT64] = "m",
            [GLSL_TYPE_INT64] = "l",
            [GLSL_TYPE_BOOL] = "b",
            [GLSL_TYPE_ERROR] = NULL,
         };
         enum glsl_base_type glsl_base_type = glsl_get_base_type(type);
         assert(glsl_base_type < ARRAY_SIZE(primitives) && primitives[glsl_base_type]);
         suffix = primitives[glsl_base_type];
         break;
      }
      }
      vtn_opencl_mangle_append(&args_str, limit, suffix);
   }

   *outstring = strdup(local_name);
}
/* Mangles name with the given parameter types and looks the result up, first
 * in the shader being built and then in the optional clc shader supplied via
 * the options.  When the function only exists in the clc shader, a
 * declaration with the same parameter list is created in the shader being
 * built and that declaration is returned.
 *
 * Reports the failure through vtn_fail when no function with the mangled
 * name exists; returns the function otherwise.
 */
static nir_function *mangle_and_find(struct vtn_builder *b,
                                     const char *name,
                                     uint32_t const_mask,
                                     uint32_t num_srcs,
                                     struct vtn_type **src_types)
{
   char *mname;
   nir_function *found = NULL;

   vtn_opencl_mangle(name, const_mask, num_srcs, src_types, &mname);

   /* try and find in current shader first. */
   nir_foreach_function(funcs, b->shader) {
      /* NOTE(review): functions without a name are presumably possible;
       * skip them rather than hand NULL to strcmp.
       */
      if (funcs->name && !strcmp(funcs->name, mname)) {
         found = funcs;
         break;
      }
   }

   /* if not found here find in clc shader and create a decl mirroring it */
   if (!found && b->options->clc_shader && b->options->clc_shader != b->shader) {
      nir_foreach_function(funcs, b->options->clc_shader) {
         if (funcs->name && !strcmp(funcs->name, mname)) {
            found = funcs;
            break;
         }
      }

      if (found) {
         nir_function *decl = nir_function_create(b->shader, mname);
         decl->num_params = found->num_params;
         decl->params = ralloc_array(b->shader, nir_parameter, decl->num_params);
         for (unsigned i = 0; i < decl->num_params; i++) {
            decl->params[i] = found->params[i];
         }
         found = decl;
      }
   }

   if (!found) {
      /* Release the heap copy before reporting: vtn_fail is not expected to
       * return (NOTE(review): confirm), in which case a free() placed after
       * it would never run.  The message uses a stack copy instead; 256
       * matches the buffer vtn_opencl_mangle builds the name in.
       */
      char missing[256];
      snprintf(missing, sizeof(missing), "%s", mname);
      free(mname);
      mname = NULL;
      vtn_fail("Can't find clc function %s\n", missing);
   }

   free(mname);
   return found;
}
/* Emits a call to the function whose mangled name is derived from name and
 * src_types.
 *
 * When dest_type is non-NULL a local "return_tmp" variable of the bare
 * destination type is created and a deref of it is passed as the first call
 * parameter; the sources follow in order.  *ret_deref_ptr receives that deref,
 * or NULL when there is no destination.
 *
 * Returns false if the function could not be found, true once the call has
 * been inserted.
 */
static bool call_mangled_function(struct vtn_builder *b,
                                  const char *name,
                                  uint32_t const_mask,
                                  uint32_t num_srcs,
                                  struct vtn_type **src_types,
                                  const struct vtn_type *dest_type,
                                  nir_ssa_def **srcs,
                                  nir_deref_instr **ret_deref_ptr)
{
   nir_function *callee = mangle_and_find(b, name, const_mask, num_srcs, src_types);
   if (callee == NULL)
      return false;

   nir_call_instr *call_instr = nir_call_instr_create(b->shader, callee);
   nir_deref_instr *result_deref = NULL;
   uint32_t next_param = 0;

   if (dest_type != NULL) {
      /* The return value travels through a temporary passed as parameter 0. */
      nir_variable *result_var =
         nir_local_variable_create(b->nb.impl,
                                   glsl_get_bare_type(dest_type->type),
                                   "return_tmp");
      result_deref = nir_build_deref_var(&b->nb, result_var);
      call_instr->params[next_param] = nir_src_for_ssa(&result_deref->dest.ssa);
      next_param++;
   }

   for (unsigned src = 0; src < num_srcs; src++, next_param++)
      call_instr->params[next_param] = nir_src_for_ssa(srcs[src]);

   nir_builder_instr_insert(&b->nb, &call_instr->instr);

   *ret_deref_ptr = result_deref;
   return true;
}
static void
handle_instr(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
const uint32_t *w, unsigned count, nir_handler handler)
@@ -129,6 +297,189 @@ handle_alu(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
return ret;
}
/* Table mapping an OpenCL extended-instruction opcode to the unmangled name
 * of the function that implements it.  Entries are placed with designated
 * initializers indexed by the OpenCLstd_* opcode, so any opcode below the
 * table's size that is not listed has a NULL fn.
 */
#define REMAP(op, str) [OpenCLstd_##op] = { str }
static const struct {
const char *fn;
} remap_table[] = {
REMAP(Distance, "distance"),
REMAP(Fast_distance, "fast_distance"),
REMAP(Fast_length, "fast_length"),
REMAP(Fast_normalize, "fast_normalize"),
REMAP(Half_rsqrt, "half_rsqrt"),
REMAP(Half_sqrt, "half_sqrt"),
REMAP(Length, "length"),
REMAP(Normalize, "normalize"),
REMAP(Degrees, "degrees"),
REMAP(Radians, "radians"),
REMAP(Rotate, "rotate"),
REMAP(Smoothstep, "smoothstep"),
REMAP(Step, "step"),
REMAP(Pow, "pow"),
REMAP(Pown, "pown"),
REMAP(Powr, "powr"),
REMAP(Rootn, "rootn"),
REMAP(Modf, "modf"),
REMAP(Acos, "acos"),
REMAP(Acosh, "acosh"),
REMAP(Acospi, "acospi"),
REMAP(Asin, "asin"),
REMAP(Asinh, "asinh"),
REMAP(Asinpi, "asinpi"),
REMAP(Atan, "atan"),
REMAP(Atan2, "atan2"),
REMAP(Atanh, "atanh"),
REMAP(Atanpi, "atanpi"),
REMAP(Atan2pi, "atan2pi"),
REMAP(Cos, "cos"),
REMAP(Cosh, "cosh"),
REMAP(Cospi, "cospi"),
REMAP(Sin, "sin"),
REMAP(Sinh, "sinh"),
REMAP(Sinpi, "sinpi"),
REMAP(Tan, "tan"),
REMAP(Tanh, "tanh"),
REMAP(Tanpi, "tanpi"),
REMAP(Sincos, "sincos"),
REMAP(Fract, "fract"),
REMAP(Frexp, "frexp"),
REMAP(Fma, "fma"),
REMAP(Fmod, "fmod"),
/* Unlike Half_rsqrt/Half_sqrt above, the Half_* opcodes below are mapped to
 * the plain function names without the half_ prefix.
 */
REMAP(Half_cos, "cos"),
REMAP(Half_exp, "exp"),
REMAP(Half_exp2, "exp2"),
REMAP(Half_exp10, "exp10"),
REMAP(Half_log, "log"),
REMAP(Half_log2, "log2"),
REMAP(Half_log10, "log10"),
REMAP(Half_powr, "powr"),
REMAP(Half_sin, "sin"),
REMAP(Half_tan, "tan"),
REMAP(Remainder, "remainder"),
REMAP(Remquo, "remquo"),
REMAP(Hypot, "hypot"),
REMAP(Exp, "exp"),
REMAP(Exp2, "exp2"),
REMAP(Exp10, "exp10"),
REMAP(Expm1, "expm1"),
REMAP(Ldexp, "ldexp"),
REMAP(Ilogb, "ilogb"),
REMAP(Log, "log"),
REMAP(Log2, "log2"),
REMAP(Log10, "log10"),
REMAP(Log1p, "log1p"),
REMAP(Logb, "logb"),
REMAP(Cbrt, "cbrt"),
REMAP(Erfc, "erfc"),
REMAP(Erf, "erf"),
REMAP(Lgamma, "lgamma"),
REMAP(Lgamma_r, "lgamma_r"),
REMAP(Tgamma, "tgamma"),
/* The unsigned and signed variants share one unmangled name. */
REMAP(UMad_sat, "mad_sat"),
REMAP(SMad_sat, "mad_sat"),
REMAP(Shuffle, "shuffle"),
REMAP(Shuffle2, "shuffle2"),
};
#undef REMAP
/* Returns the unmangled function name registered for opcode in remap_table,
 * or NULL when the opcode lies beyond the table or has no entry.
 */
static const char *remap_clc_opcode(enum OpenCLstd_Entrypoints opcode)
{
   /* Divide by the element size rather than by sizeof(const char *), so the
    * bound stays correct if the table's element struct ever gains a member.
    */
   if ((size_t)opcode >= sizeof(remap_table) / sizeof(remap_table[0]))
      return NULL;

   return remap_table[opcode].fn;
}
/* Allocates a zeroed vtn_type (owned by b) describing the scalar or vector
 * GLSL type passed in.  Only vector and scalar types are accepted.
 */
static struct vtn_type *
get_vtn_type_for_glsl_type(struct vtn_builder *b, const struct glsl_type *type)
{
   struct vtn_type *wrapped = rzalloc(b, struct vtn_type);

   assert(glsl_type_is_vector_or_scalar(type));

   if (glsl_type_is_vector(type))
      wrapped->base_type = vtn_base_type_vector;
   else
      wrapped->base_type = vtn_base_type_scalar;

   wrapped->length = glsl_get_vector_elements(type);
   wrapped->type = type;

   return wrapped;
}
/* Allocates a zeroed vtn_type (owned by b) for a pointer to t in the given
 * storage class.  The pointer's GLSL type is derived from the address format
 * that the storage class's variable mode maps to.
 */
static struct vtn_type *
get_pointer_type(struct vtn_builder *b, struct vtn_type *t, SpvStorageClass storage_class)
{
   struct vtn_type *pointer = rzalloc(b, struct vtn_type);

   pointer->base_type = vtn_base_type_pointer;
   pointer->storage_class = storage_class;
   pointer->deref = t;

   /* storage class -> variable mode -> address format -> GLSL type */
   pointer->type = nir_address_format_to_glsl_type(
      vtn_mode_to_address_format(
         b, vtn_storage_class_to_mode(b, storage_class, NULL, NULL)));

   return pointer;
}
/* Returns a newly allocated type equal to t except that its base type is
 * replaced by the signed counterpart.  For a pointer, the pointee is
 * converted recursively and re-wrapped in a pointer of the same storage
 * class.
 */
static struct vtn_type *
get_signed_type(struct vtn_builder *b, struct vtn_type *t)
{
   if (t->base_type != vtn_base_type_pointer) {
      /* Scalar/vector: keep the component count, swap the base type. */
      return get_vtn_type_for_glsl_type(
         b, glsl_vector_type(glsl_signed_base_type_of(glsl_get_base_type(t->type)),
                             glsl_get_vector_elements(t->type)));
   }

   struct vtn_type *signed_pointee = get_signed_type(b, t->deref);
   return get_pointer_type(b, signed_pointee, t->storage_class);
}
/* Implements an OpenCL extended instruction by calling the function named in
 * remap_table for it.
 *
 * opcode:    extended-instruction opcode.
 * num_srcs:  number of valid entries in srcs and src_types.
 * srcs:      source values passed to the call.
 * src_types: types of the sources; entries may be replaced in place with
 *            their signed equivalents so that the mangled name matches.
 * dest_type: type of the result, or NULL if there is none.
 *
 * Returns the value loaded from the call's return temporary.  Returns NULL
 * when the opcode has no table entry, when the instruction carries fewer
 * sources than the opcode needs, when the function cannot be called, or when
 * there is no destination.
 */
static nir_ssa_def *
handle_clc_fn(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
              int num_srcs,
              nir_ssa_def **srcs,
              struct vtn_type **src_types,
              const struct vtn_type *dest_type)
{
   const char *name = remap_clc_opcode(opcode);
   if (!name)
      return NULL;

   /* Some functions which take params end up with uint (or pointer-to-uint) being passed,
    * which doesn't mangle correctly when the function expects int or pointer-to-int.
    * See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers
    */
   int signed_param = -1;
   switch (opcode) {
   case OpenCLstd_Frexp:
   case OpenCLstd_Lgamma_r:
   case OpenCLstd_Pown:
   case OpenCLstd_Rootn:
   case OpenCLstd_Ldexp:
      signed_param = 1;
      break;
   case OpenCLstd_Remquo:
      signed_param = 2;
      break;
   case OpenCLstd_SMad_sat:
      /* All parameters need to be converted to signed.  Bail out instead of
       * touching src_types entries the instruction did not provide.
       */
      if (num_srcs < 3)
         return NULL;
      src_types[0] = src_types[1] = src_types[2] = get_signed_type(b, src_types[0]);
      break;
   default:
      break;
   }

   if (signed_param >= 0) {
      /* The parameter to convert must be one the instruction supplied. */
      if (signed_param >= num_srcs)
         return NULL;
      src_types[signed_param] = get_signed_type(b, src_types[signed_param]);
   }

   nir_deref_instr *ret_deref = NULL;
   if (!call_mangled_function(b, name, 0, num_srcs, src_types,
                              dest_type, srcs, &ret_deref))
      return NULL;

   return ret_deref ? nir_load_deref(&b->nb, ret_deref) : NULL;
}
static nir_ssa_def *
handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
@@ -138,6 +489,7 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
switch (opcode) {
case OpenCLstd_SAbs_diff:
/* these works easier in direct NIR */
return nir_iabs_diff(nb, srcs[0], srcs[1]);
case OpenCLstd_UAbs_diff:
return nir_uabs_diff(nb, srcs[0], srcs[1]);
@@ -207,6 +559,7 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
return nir_sge(nb, srcs[1], srcs[0]);
case OpenCLstd_S_Upsample:
case OpenCLstd_U_Upsample:
/* SPIR-V and CL have different defs for upsample, just implement in nir */
return nir_upsample(nb, srcs[0], srcs[1]);
case OpenCLstd_Native_exp:
return nir_fexp(nb, srcs[0]);
@@ -219,9 +572,14 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
case OpenCLstd_Native_tan:
return nir_ftan(nb, srcs[0]);
default:
vtn_fail("No NIR equivalent");
return NULL;
break;
}
nir_ssa_def *ret = handle_clc_fn(b, opcode, num_srcs, srcs, src_types, dest_type);
if (!ret)
vtn_fail("No NIR equivalent");
return ret;
}
static void