microsoft/compiler: Implement more wave/quad ops

This handles ballot, vote, shuffle, broadcast, and quads

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20801>
This commit is contained in:
Jesse Natalie
2023-01-19 13:25:53 -08:00
parent a318c101bb
commit 2c5d96bb58
6 changed files with 162 additions and 0 deletions

View File

@@ -347,6 +347,12 @@ enum dxil_signature_element_extended_properties {
DXIL_SIGNATURE_ELEMENT_USAGE_COMPONENT_MASK = 3,
};
enum dxil_quad_op_kind {
QUAD_READ_ACROSS_X = 0,
QUAD_READ_ACROSS_Y = 1,
QUAD_READ_ACROSS_DIAGONAL = 2,
};
#ifdef __cplusplus
extern "C" {
#endif

View File

@@ -97,6 +97,13 @@ static struct predefined_func_descr predefined_funcs[] = {
{"dx.op.waveGetLaneIndex", "i", "i", DXIL_ATTR_KIND_READ_NONE},
{"dx.op.waveGetLaneCount", "i", "i", DXIL_ATTR_KIND_READ_NONE},
{"dx.op.waveReadLaneFirst", "O", "iO", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.waveReadLaneAt", "O", "iOi", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.waveAnyTrue", "b", "ib", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.waveAllTrue", "b", "ib", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.waveActiveAllEqual", "b", "iO", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.waveActiveBallot", "F", "ib", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.quadReadLaneAt", "O", "iOi", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.quadOp", "O", "iOc", DXIL_ATTR_KIND_NO_UNWIND},
};
struct func_descr {
@@ -207,6 +214,7 @@ get_type_from_string(struct dxil_module *mod, const char *param_descr,
const struct dxil_type *target = get_type_from_string(mod, param_descr, overload, idx);
return dxil_module_get_pointer_type(mod, target);
}
case DXIL_FUNC_PARAM_FOURI32: return dxil_module_get_fouri32_type(mod);
default:
assert(0 && "unknown type identifier");
}

View File

@@ -44,6 +44,7 @@
#define DXIL_FUNC_PARAM_SAMPLE_POS 'S'
#define DXIL_FUNC_PARAM_RES_BIND '#'
#define DXIL_FUNC_PARAM_RES_PROPS 'P'
#define DXIL_FUNC_PARAM_FOURI32 'F'
#include "dxil_module.h"
#include "util/rb_tree.h"

View File

@@ -909,6 +909,16 @@ dxil_module_get_res_props_type(struct dxil_module *mod)
return dxil_module_get_struct_type(mod, "dx.types.ResourceProperties", fields, 2);
}
const struct dxil_type *
dxil_module_get_fouri32_type(struct dxil_module *mod)
{
/* %dx.types.fouri32 = type { i32, i32, i32, i32 } */
const struct dxil_type *int32_type = dxil_module_get_int_type(mod, 32);
const struct dxil_type *fields[4] = { int32_type, int32_type, int32_type, int32_type };
return dxil_module_get_struct_type(mod, "dx.types.fouri32", fields, 4);
}
const struct dxil_type *
dxil_module_add_function_type(struct dxil_module *m,
const struct dxil_type *ret_type,

View File

@@ -323,6 +323,9 @@ dxil_module_get_res_bind_type(struct dxil_module *m);
const struct dxil_type *
dxil_module_get_res_props_type(struct dxil_module *m);
const struct dxil_type *
dxil_module_get_fouri32_type(struct dxil_module *m);
const struct dxil_type *
dxil_module_get_struct_type(struct dxil_module *m,
const char *name,

View File

@@ -333,7 +333,14 @@ enum dxil_intr {
DXIL_INTR_WAVE_IS_FIRST_LANE = 110,
DXIL_INTR_WAVE_GET_LANE_INDEX = 111,
DXIL_INTR_WAVE_GET_LANE_COUNT = 112,
DXIL_INTR_WAVE_ANY_TRUE = 113,
DXIL_INTR_WAVE_ALL_TRUE = 114,
DXIL_INTR_WAVE_ACTIVE_ALL_EQUAL = 115,
DXIL_INTR_WAVE_ACTIVE_BALLOT = 116,
DXIL_INTR_WAVE_READ_LANE_AT = 117,
DXIL_INTR_WAVE_READ_LANE_FIRST = 118,
DXIL_INTR_QUAD_READ_LANE_AT = 122,
DXIL_INTR_QUAD_OP = 123,
DXIL_INTR_LEGACY_F32TOF16 = 130,
DXIL_INTR_LEGACY_F16TOF32 = 131,
@@ -4427,6 +4434,112 @@ emit_read_first_invocation(struct ntd_context *ctx, nir_intrinsic_instr *intr)
return true;
}
static bool
emit_read_invocation(struct ntd_context *ctx, nir_intrinsic_instr *intr)
{
ctx->mod.feats.wave_ops = 1;
bool quad = intr->intrinsic == nir_intrinsic_quad_broadcast;
const struct dxil_func *func = dxil_get_function(&ctx->mod, quad ? "dx.op.quadReadLaneAt" : "dx.op.waveReadLaneAt",
get_overload(nir_type_uint, intr->dest.ssa.bit_size));
const struct dxil_value *args[] = {
dxil_module_get_int32_const(&ctx->mod, quad ? DXIL_INTR_QUAD_READ_LANE_AT : DXIL_INTR_WAVE_READ_LANE_AT),
get_src(ctx, &intr->src[0], 0, nir_type_uint),
get_src(ctx, &intr->src[1], 0, nir_type_uint),
};
if (!func || !args[0] || !args[1] || !args[2])
return false;
const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
if (!ret)
return false;
store_dest_value(ctx, &intr->dest, 0, ret);
return true;
}
static bool
emit_vote_eq(struct ntd_context *ctx, nir_intrinsic_instr *intr)
{
ctx->mod.feats.wave_ops = 1;
nir_alu_type alu_type = intr->intrinsic == nir_intrinsic_vote_ieq ? nir_type_int : nir_type_float;
const struct dxil_func *func = dxil_get_function(&ctx->mod, "dx.op.waveActiveAllEqual",
get_overload(alu_type, intr->src[0].ssa->bit_size));
const struct dxil_value *args[] = {
dxil_module_get_int32_const(&ctx->mod, DXIL_INTR_WAVE_ACTIVE_ALL_EQUAL),
get_src(ctx, intr->src, 0, alu_type),
};
if (!func || !args[0] || !args[1])
return false;
const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
if (!ret)
return false;
store_dest_value(ctx, &intr->dest, 0, ret);
return true;
}
static bool
emit_vote(struct ntd_context *ctx, nir_intrinsic_instr *intr)
{
ctx->mod.feats.wave_ops = 1;
bool any = intr->intrinsic == nir_intrinsic_vote_any;
const struct dxil_func *func = dxil_get_function(&ctx->mod,
any ? "dx.op.waveAnyTrue" : "dx.op.waveAllTrue",
DXIL_NONE);
const struct dxil_value *args[] = {
dxil_module_get_int32_const(&ctx->mod, any ? DXIL_INTR_WAVE_ANY_TRUE : DXIL_INTR_WAVE_ALL_TRUE),
get_src(ctx, intr->src, 0, nir_type_bool),
};
if (!func || !args[0] || !args[1])
return false;
const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
if (!ret)
return false;
store_dest_value(ctx, &intr->dest, 0, ret);
return true;
}
static bool
emit_ballot(struct ntd_context *ctx, nir_intrinsic_instr *intr)
{
ctx->mod.feats.wave_ops = 1;
const struct dxil_func *func = dxil_get_function(&ctx->mod, "dx.op.waveActiveBallot", DXIL_NONE);
const struct dxil_value *args[] = {
dxil_module_get_int32_const(&ctx->mod, DXIL_INTR_WAVE_ACTIVE_BALLOT),
get_src(ctx, intr->src, 0, nir_type_bool),
};
if (!func || !args[0] || !args[1])
return false;
const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
if (!ret)
return false;
for (uint32_t i = 0; i < 4; ++i)
store_dest_value(ctx, &intr->dest, i, dxil_emit_extractval(&ctx->mod, ret, i));
return true;
}
static bool
emit_quad_op(struct ntd_context *ctx, nir_intrinsic_instr *intr, enum dxil_quad_op_kind op)
{
ctx->mod.feats.wave_ops = 1;
const struct dxil_func *func = dxil_get_function(&ctx->mod, "dx.op.quadOp",
get_overload(nir_type_uint, intr->dest.ssa.bit_size));
const struct dxil_value *args[] = {
dxil_module_get_int32_const(&ctx->mod, DXIL_INTR_QUAD_OP),
get_src(ctx, intr->src, 0, nir_type_uint),
dxil_module_get_int8_const(&ctx->mod, op),
};
if (!func || !args[0] || !args[1] || !args[2])
return false;
const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
if (!ret)
return false;
store_dest_value(ctx, &intr->dest, 0, ret);
return true;
}
static bool
emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
{
@@ -4634,8 +4747,29 @@ emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
return emit_load_unary_external_function(
ctx, intr, "dx.op.waveGetLaneIndex", DXIL_INTR_WAVE_GET_LANE_INDEX, DXIL_NONE);
case nir_intrinsic_vote_feq:
case nir_intrinsic_vote_ieq:
return emit_vote_eq(ctx, intr);
case nir_intrinsic_vote_any:
case nir_intrinsic_vote_all:
return emit_vote(ctx, intr);
case nir_intrinsic_ballot:
return emit_ballot(ctx, intr);
case nir_intrinsic_read_first_invocation:
return emit_read_first_invocation(ctx, intr);
case nir_intrinsic_read_invocation:
case nir_intrinsic_shuffle:
case nir_intrinsic_quad_broadcast:
return emit_read_invocation(ctx, intr);
case nir_intrinsic_quad_swap_horizontal:
return emit_quad_op(ctx, intr, QUAD_READ_ACROSS_X);
case nir_intrinsic_quad_swap_vertical:
return emit_quad_op(ctx, intr, QUAD_READ_ACROSS_Y);
case nir_intrinsic_quad_swap_diagonal:
return emit_quad_op(ctx, intr, QUAD_READ_ACROSS_DIAGONAL);
case nir_intrinsic_load_num_workgroups:
case nir_intrinsic_load_workgroup_size: