vtn/opencl: Add infrastructure for calling out to libclc
This patch adds a function remap table with name mangling, which can convert a SPIR-V OpenCL extension opcode to a call to the external libclc shader, which will be lowered/inlined after conversion. Reviewed-by: Dave Airlie <airlied@redhat.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6035>
This commit is contained in:
@@ -83,6 +83,8 @@ struct spirv_to_nir_options {
|
||||
nir_address_format temp_addr_format;
|
||||
nir_address_format constant_addr_format;
|
||||
|
||||
nir_shader *clc_shader;
|
||||
|
||||
struct {
|
||||
void (*func)(void *private_data,
|
||||
enum nir_spirv_debug_level level,
|
||||
|
@@ -36,6 +36,174 @@ typedef nir_ssa_def *(*nir_handler)(struct vtn_builder *b,
|
||||
struct vtn_type **src_types,
|
||||
const struct vtn_type *dest_type);
|
||||
|
||||
/* Translate a SPIR-V storage class into the numeric address space used when
 * mangling pointer arguments.
 *
 * Returns 0 for Private/Function, 1 for CrossWorkgroup, 2 for
 * Uniform/UniformConstant, 3 for Workgroup, and -1 for any storage class
 * that has no mapping here.
 */
static int to_llvm_address_space(SpvStorageClass mode)
{
   int address_space = -1;

   if (mode == SpvStorageClassPrivate || mode == SpvStorageClassFunction)
      address_space = 0;
   else if (mode == SpvStorageClassCrossWorkgroup)
      address_space = 1;
   else if (mode == SpvStorageClassUniform ||
            mode == SpvStorageClassUniformConstant)
      address_space = 2;
   else if (mode == SpvStorageClassWorkgroup)
      address_space = 3;

   return address_space;
}
|
||||
|
||||
|
||||
static void
|
||||
vtn_opencl_mangle(const char *in_name,
|
||||
uint32_t const_mask,
|
||||
int ntypes, struct vtn_type **src_types,
|
||||
char **outstring)
|
||||
{
|
||||
char local_name[256] = "";
|
||||
char *args_str = local_name + sprintf(local_name, "_Z%zu%s", strlen(in_name), in_name);
|
||||
|
||||
for (unsigned i = 0; i < ntypes; ++i) {
|
||||
const struct glsl_type *type = src_types[i]->type;
|
||||
enum vtn_base_type base_type = src_types[i]->base_type;
|
||||
if (src_types[i]->base_type == vtn_base_type_pointer) {
|
||||
*(args_str++) = 'P';
|
||||
int address_space = to_llvm_address_space(src_types[i]->storage_class);
|
||||
if (address_space > 0)
|
||||
args_str += sprintf(args_str, "U3AS%d", address_space);
|
||||
|
||||
type = src_types[i]->deref->type;
|
||||
base_type = src_types[i]->deref->base_type;
|
||||
}
|
||||
|
||||
if (const_mask & (1 << i))
|
||||
*(args_str++) = 'K';
|
||||
|
||||
unsigned num_elements = glsl_get_components(type);
|
||||
if (num_elements > 1) {
|
||||
/* Vectors are not treated as built-ins for mangling, so check for substitution.
|
||||
* In theory, we'd need to know which substitution value this is. In practice,
|
||||
* the functions we need from libclc only support 1
|
||||
*/
|
||||
bool substitution = false;
|
||||
for (unsigned j = 0; j < i; ++j) {
|
||||
const struct glsl_type *other_type = src_types[j]->base_type == vtn_base_type_pointer ?
|
||||
src_types[j]->deref->type : src_types[j]->type;
|
||||
if (type == other_type) {
|
||||
substitution = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (substitution) {
|
||||
args_str += sprintf(args_str, "S_");
|
||||
continue;
|
||||
} else
|
||||
args_str += sprintf(args_str, "Dv%d_", num_elements);
|
||||
}
|
||||
|
||||
const char *suffix = NULL;
|
||||
switch (base_type) {
|
||||
case vtn_base_type_sampler: suffix = "11ocl_sampler"; break;
|
||||
case vtn_base_type_event: suffix = "9ocl_event"; break;
|
||||
default: {
|
||||
const char *primitives[] = {
|
||||
[GLSL_TYPE_UINT] = "j",
|
||||
[GLSL_TYPE_INT] = "i",
|
||||
[GLSL_TYPE_FLOAT] = "f",
|
||||
[GLSL_TYPE_FLOAT16] = "Dh",
|
||||
[GLSL_TYPE_DOUBLE] = "d",
|
||||
[GLSL_TYPE_UINT8] = "h",
|
||||
[GLSL_TYPE_INT8] = "c",
|
||||
[GLSL_TYPE_UINT16] = "t",
|
||||
[GLSL_TYPE_INT16] = "s",
|
||||
[GLSL_TYPE_UINT64] = "m",
|
||||
[GLSL_TYPE_INT64] = "l",
|
||||
[GLSL_TYPE_BOOL] = "b",
|
||||
[GLSL_TYPE_ERROR] = NULL,
|
||||
};
|
||||
enum glsl_base_type glsl_base_type = glsl_get_base_type(type);
|
||||
assert(glsl_base_type < ARRAY_SIZE(primitives) && primitives[glsl_base_type]);
|
||||
suffix = primitives[glsl_base_type];
|
||||
break;
|
||||
}
|
||||
}
|
||||
args_str += sprintf(args_str, "%s", suffix);
|
||||
}
|
||||
|
||||
*outstring = strdup(local_name);
|
||||
}
|
||||
|
||||
static nir_function *mangle_and_find(struct vtn_builder *b,
|
||||
const char *name,
|
||||
uint32_t const_mask,
|
||||
uint32_t num_srcs,
|
||||
struct vtn_type **src_types)
|
||||
{
|
||||
char *mname;
|
||||
nir_function *found = NULL;
|
||||
|
||||
vtn_opencl_mangle(name, const_mask, num_srcs, src_types, &mname);
|
||||
/* try and find in current shader first. */
|
||||
nir_foreach_function(funcs, b->shader) {
|
||||
if (!strcmp(funcs->name, mname)) {
|
||||
found = funcs;
|
||||
break;
|
||||
}
|
||||
}
|
||||
/* if not found here find in clc shader and create a decl mirroring it */
|
||||
if (!found && b->options->clc_shader && b->options->clc_shader != b->shader) {
|
||||
nir_foreach_function(funcs, b->options->clc_shader) {
|
||||
if (!strcmp(funcs->name, mname)) {
|
||||
found = funcs;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found) {
|
||||
nir_function *decl = nir_function_create(b->shader, mname);
|
||||
decl->num_params = found->num_params;
|
||||
decl->params = ralloc_array(b->shader, nir_parameter, decl->num_params);
|
||||
for (unsigned i = 0; i < decl->num_params; i++) {
|
||||
decl->params[i] = found->params[i];
|
||||
}
|
||||
found = decl;
|
||||
}
|
||||
}
|
||||
if (!found)
|
||||
vtn_fail("Can't find clc function %s\n", mname);
|
||||
free(mname);
|
||||
return found;
|
||||
}
|
||||
|
||||
static bool call_mangled_function(struct vtn_builder *b,
|
||||
const char *name,
|
||||
uint32_t const_mask,
|
||||
uint32_t num_srcs,
|
||||
struct vtn_type **src_types,
|
||||
const struct vtn_type *dest_type,
|
||||
nir_ssa_def **srcs,
|
||||
nir_deref_instr **ret_deref_ptr)
|
||||
{
|
||||
nir_function *found = mangle_and_find(b, name, const_mask, num_srcs, src_types);
|
||||
if (!found)
|
||||
return false;
|
||||
|
||||
nir_call_instr *call = nir_call_instr_create(b->shader, found);
|
||||
|
||||
nir_deref_instr *ret_deref = NULL;
|
||||
uint32_t param_idx = 0;
|
||||
if (dest_type) {
|
||||
nir_variable *ret_tmp = nir_local_variable_create(b->nb.impl,
|
||||
glsl_get_bare_type(dest_type->type),
|
||||
"return_tmp");
|
||||
ret_deref = nir_build_deref_var(&b->nb, ret_tmp);
|
||||
call->params[param_idx++] = nir_src_for_ssa(&ret_deref->dest.ssa);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < num_srcs; i++)
|
||||
call->params[param_idx++] = nir_src_for_ssa(srcs[i]);
|
||||
nir_builder_instr_insert(&b->nb, &call->instr);
|
||||
|
||||
*ret_deref_ptr = ret_deref;
|
||||
return true;
|
||||
}
|
||||
|
||||
static void
|
||||
handle_instr(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
|
||||
const uint32_t *w, unsigned count, nir_handler handler)
|
||||
@@ -129,6 +297,189 @@ handle_alu(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
|
||||
return ret;
|
||||
}
|
||||
|
||||
#define REMAP(op, str) [OpenCLstd_##op] = { str }
|
||||
static const struct {
|
||||
const char *fn;
|
||||
} remap_table[] = {
|
||||
REMAP(Distance, "distance"),
|
||||
REMAP(Fast_distance, "fast_distance"),
|
||||
REMAP(Fast_length, "fast_length"),
|
||||
REMAP(Fast_normalize, "fast_normalize"),
|
||||
REMAP(Half_rsqrt, "half_rsqrt"),
|
||||
REMAP(Half_sqrt, "half_sqrt"),
|
||||
REMAP(Length, "length"),
|
||||
REMAP(Normalize, "normalize"),
|
||||
REMAP(Degrees, "degrees"),
|
||||
REMAP(Radians, "radians"),
|
||||
REMAP(Rotate, "rotate"),
|
||||
REMAP(Smoothstep, "smoothstep"),
|
||||
REMAP(Step, "step"),
|
||||
|
||||
REMAP(Pow, "pow"),
|
||||
REMAP(Pown, "pown"),
|
||||
REMAP(Powr, "powr"),
|
||||
REMAP(Rootn, "rootn"),
|
||||
REMAP(Modf, "modf"),
|
||||
|
||||
REMAP(Acos, "acos"),
|
||||
REMAP(Acosh, "acosh"),
|
||||
REMAP(Acospi, "acospi"),
|
||||
REMAP(Asin, "asin"),
|
||||
REMAP(Asinh, "asinh"),
|
||||
REMAP(Asinpi, "asinpi"),
|
||||
REMAP(Atan, "atan"),
|
||||
REMAP(Atan2, "atan2"),
|
||||
REMAP(Atanh, "atanh"),
|
||||
REMAP(Atanpi, "atanpi"),
|
||||
REMAP(Atan2pi, "atan2pi"),
|
||||
REMAP(Cos, "cos"),
|
||||
REMAP(Cosh, "cosh"),
|
||||
REMAP(Cospi, "cospi"),
|
||||
REMAP(Sin, "sin"),
|
||||
REMAP(Sinh, "sinh"),
|
||||
REMAP(Sinpi, "sinpi"),
|
||||
REMAP(Tan, "tan"),
|
||||
REMAP(Tanh, "tanh"),
|
||||
REMAP(Tanpi, "tanpi"),
|
||||
REMAP(Sincos, "sincos"),
|
||||
REMAP(Fract, "fract"),
|
||||
REMAP(Frexp, "frexp"),
|
||||
REMAP(Fma, "fma"),
|
||||
REMAP(Fmod, "fmod"),
|
||||
|
||||
REMAP(Half_cos, "cos"),
|
||||
REMAP(Half_exp, "exp"),
|
||||
REMAP(Half_exp2, "exp2"),
|
||||
REMAP(Half_exp10, "exp10"),
|
||||
REMAP(Half_log, "log"),
|
||||
REMAP(Half_log2, "log2"),
|
||||
REMAP(Half_log10, "log10"),
|
||||
REMAP(Half_powr, "powr"),
|
||||
REMAP(Half_sin, "sin"),
|
||||
REMAP(Half_tan, "tan"),
|
||||
|
||||
REMAP(Remainder, "remainder"),
|
||||
REMAP(Remquo, "remquo"),
|
||||
REMAP(Hypot, "hypot"),
|
||||
REMAP(Exp, "exp"),
|
||||
REMAP(Exp2, "exp2"),
|
||||
REMAP(Exp10, "exp10"),
|
||||
REMAP(Expm1, "expm1"),
|
||||
REMAP(Ldexp, "ldexp"),
|
||||
|
||||
REMAP(Ilogb, "ilogb"),
|
||||
REMAP(Log, "log"),
|
||||
REMAP(Log2, "log2"),
|
||||
REMAP(Log10, "log10"),
|
||||
REMAP(Log1p, "log1p"),
|
||||
REMAP(Logb, "logb"),
|
||||
|
||||
REMAP(Cbrt, "cbrt"),
|
||||
REMAP(Erfc, "erfc"),
|
||||
REMAP(Erf, "erf"),
|
||||
|
||||
REMAP(Lgamma, "lgamma"),
|
||||
REMAP(Lgamma_r, "lgamma_r"),
|
||||
REMAP(Tgamma, "tgamma"),
|
||||
|
||||
REMAP(UMad_sat, "mad_sat"),
|
||||
REMAP(SMad_sat, "mad_sat"),
|
||||
|
||||
REMAP(Shuffle, "shuffle"),
|
||||
REMAP(Shuffle2, "shuffle2"),
|
||||
};
|
||||
#undef REMAP
|
||||
|
||||
static const char *remap_clc_opcode(enum OpenCLstd_Entrypoints opcode)
|
||||
{
|
||||
if (opcode >= (sizeof(remap_table) / sizeof(const char *)))
|
||||
return NULL;
|
||||
return remap_table[opcode].fn;
|
||||
}
|
||||
|
||||
static struct vtn_type *
|
||||
get_vtn_type_for_glsl_type(struct vtn_builder *b, const struct glsl_type *type)
|
||||
{
|
||||
struct vtn_type *ret = rzalloc(b, struct vtn_type);
|
||||
assert(glsl_type_is_vector_or_scalar(type));
|
||||
ret->type = type;
|
||||
ret->length = glsl_get_vector_elements(type);
|
||||
ret->base_type = glsl_type_is_vector(type) ? vtn_base_type_vector : vtn_base_type_scalar;
|
||||
return ret;
|
||||
}
|
||||
|
||||
static struct vtn_type *
|
||||
get_pointer_type(struct vtn_builder *b, struct vtn_type *t, SpvStorageClass storage_class)
|
||||
{
|
||||
struct vtn_type *ret = rzalloc(b, struct vtn_type);
|
||||
ret->type = nir_address_format_to_glsl_type(
|
||||
vtn_mode_to_address_format(
|
||||
b, vtn_storage_class_to_mode(b, storage_class, NULL, NULL)));
|
||||
ret->base_type = vtn_base_type_pointer;
|
||||
ret->storage_class = storage_class;
|
||||
ret->deref = t;
|
||||
return ret;
|
||||
}
|
||||
|
||||
static struct vtn_type *
|
||||
get_signed_type(struct vtn_builder *b, struct vtn_type *t)
|
||||
{
|
||||
if (t->base_type == vtn_base_type_pointer) {
|
||||
return get_pointer_type(b, get_signed_type(b, t->deref), t->storage_class);
|
||||
}
|
||||
return get_vtn_type_for_glsl_type(
|
||||
b, glsl_vector_type(glsl_signed_base_type_of(glsl_get_base_type(t->type)),
|
||||
glsl_get_vector_elements(t->type)));
|
||||
}
|
||||
|
||||
/* Implement an OpenCL extended instruction by calling the function it maps
 * to in remap_table.
 *
 * Returns the value loaded from the call's return temporary.  Returns NULL
 * when the opcode has no remap entry, when the call could not be emitted, or
 * when `dest_type` is NULL (nothing to load).
 *
 * Note: entries of `src_types` may be replaced in place with signed
 * variants, see below.
 */
static nir_ssa_def *
handle_clc_fn(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
              int num_srcs,
              nir_ssa_def **srcs,
              struct vtn_type **src_types,
              const struct vtn_type *dest_type)
{
   const char *fn_name = remap_clc_opcode(opcode);
   if (fn_name == NULL)
      return NULL;

   /* Some functions which take params end up with uint (or pointer-to-uint)
    * being passed, which doesn't mangle correctly when the function expects
    * int or pointer-to-int.  Swap in the signed type for those arguments
    * before mangling.
    * See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers
    */
   switch (opcode) {
   case OpenCLstd_Frexp:
   case OpenCLstd_Lgamma_r:
   case OpenCLstd_Pown:
   case OpenCLstd_Rootn:
   case OpenCLstd_Ldexp:
      /* The second argument is the int / pointer-to-int one. */
      src_types[1] = get_signed_type(b, src_types[1]);
      break;
   case OpenCLstd_Remquo:
      /* The third argument is the pointer-to-int one. */
      src_types[2] = get_signed_type(b, src_types[2]);
      break;
   case OpenCLstd_SMad_sat: {
      /* All parameters need to be converted to signed */
      struct vtn_type *signed_type = get_signed_type(b, src_types[0]);
      src_types[0] = signed_type;
      src_types[1] = signed_type;
      src_types[2] = signed_type;
      break;
   }
   default:
      break;
   }

   nir_deref_instr *result_deref = NULL;
   if (!call_mangled_function(b, fn_name, 0, num_srcs, src_types,
                              dest_type, srcs, &result_deref))
      return NULL;

   if (result_deref == NULL)
      return NULL;

   return nir_load_deref(&b->nb, result_deref);
}
|
||||
|
||||
static nir_ssa_def *
|
||||
handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
|
||||
unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
|
||||
@@ -138,6 +489,7 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
|
||||
|
||||
switch (opcode) {
|
||||
case OpenCLstd_SAbs_diff:
|
||||
/* these are easier to implement directly in NIR */
|
||||
return nir_iabs_diff(nb, srcs[0], srcs[1]);
|
||||
case OpenCLstd_UAbs_diff:
|
||||
return nir_uabs_diff(nb, srcs[0], srcs[1]);
|
||||
@@ -207,6 +559,7 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
|
||||
return nir_sge(nb, srcs[1], srcs[0]);
|
||||
case OpenCLstd_S_Upsample:
|
||||
case OpenCLstd_U_Upsample:
|
||||
/* SPIR-V and CL have different defs for upsample, just implement in nir */
|
||||
return nir_upsample(nb, srcs[0], srcs[1]);
|
||||
case OpenCLstd_Native_exp:
|
||||
return nir_fexp(nb, srcs[0]);
|
||||
@@ -219,9 +572,14 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
|
||||
case OpenCLstd_Native_tan:
|
||||
return nir_ftan(nb, srcs[0]);
|
||||
default:
|
||||
vtn_fail("No NIR equivalent");
|
||||
return NULL;
|
||||
break;
|
||||
}
|
||||
|
||||
nir_ssa_def *ret = handle_clc_fn(b, opcode, num_srcs, srcs, src_types, dest_type);
|
||||
if (!ret)
|
||||
vtn_fail("No NIR equivalent");
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
static void
|
||||
|
Reference in New Issue
Block a user