microsoft/compiler: Implement more wave/quad ops
This handles ballot, vote, shuffle, broadcast, and quads Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20801>
This commit is contained in:
@@ -347,6 +347,12 @@ enum dxil_signature_element_extended_properties {
|
||||
DXIL_SIGNATURE_ELEMENT_USAGE_COMPONENT_MASK = 3,
|
||||
};
|
||||
|
||||
enum dxil_quad_op_kind {
|
||||
QUAD_READ_ACROSS_X = 0,
|
||||
QUAD_READ_ACROSS_Y = 1,
|
||||
QUAD_READ_ACROSS_DIAGONAL = 2,
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
@@ -97,6 +97,13 @@ static struct predefined_func_descr predefined_funcs[] = {
|
||||
{"dx.op.waveGetLaneIndex", "i", "i", DXIL_ATTR_KIND_READ_NONE},
|
||||
{"dx.op.waveGetLaneCount", "i", "i", DXIL_ATTR_KIND_READ_NONE},
|
||||
{"dx.op.waveReadLaneFirst", "O", "iO", DXIL_ATTR_KIND_NO_UNWIND},
|
||||
{"dx.op.waveReadLaneAt", "O", "iOi", DXIL_ATTR_KIND_NO_UNWIND},
|
||||
{"dx.op.waveAnyTrue", "b", "ib", DXIL_ATTR_KIND_NO_UNWIND},
|
||||
{"dx.op.waveAllTrue", "b", "ib", DXIL_ATTR_KIND_NO_UNWIND},
|
||||
{"dx.op.waveActiveAllEqual", "b", "iO", DXIL_ATTR_KIND_NO_UNWIND},
|
||||
{"dx.op.waveActiveBallot", "F", "ib", DXIL_ATTR_KIND_NO_UNWIND},
|
||||
{"dx.op.quadReadLaneAt", "O", "iOi", DXIL_ATTR_KIND_NO_UNWIND},
|
||||
{"dx.op.quadOp", "O", "iOc", DXIL_ATTR_KIND_NO_UNWIND},
|
||||
};
|
||||
|
||||
struct func_descr {
|
||||
@@ -207,6 +214,7 @@ get_type_from_string(struct dxil_module *mod, const char *param_descr,
|
||||
const struct dxil_type *target = get_type_from_string(mod, param_descr, overload, idx);
|
||||
return dxil_module_get_pointer_type(mod, target);
|
||||
}
|
||||
case DXIL_FUNC_PARAM_FOURI32: return dxil_module_get_fouri32_type(mod);
|
||||
default:
|
||||
assert(0 && "unknown type identifier");
|
||||
}
|
||||
|
@@ -44,6 +44,7 @@
|
||||
#define DXIL_FUNC_PARAM_SAMPLE_POS 'S'
|
||||
#define DXIL_FUNC_PARAM_RES_BIND '#'
|
||||
#define DXIL_FUNC_PARAM_RES_PROPS 'P'
|
||||
#define DXIL_FUNC_PARAM_FOURI32 'F'
|
||||
|
||||
#include "dxil_module.h"
|
||||
#include "util/rb_tree.h"
|
||||
|
@@ -909,6 +909,16 @@ dxil_module_get_res_props_type(struct dxil_module *mod)
|
||||
return dxil_module_get_struct_type(mod, "dx.types.ResourceProperties", fields, 2);
|
||||
}
|
||||
|
||||
const struct dxil_type *
|
||||
dxil_module_get_fouri32_type(struct dxil_module *mod)
|
||||
{
|
||||
/* %dx.types.fouri32 = type { i32, i32, i32, i32 } */
|
||||
const struct dxil_type *int32_type = dxil_module_get_int_type(mod, 32);
|
||||
const struct dxil_type *fields[4] = { int32_type, int32_type, int32_type, int32_type };
|
||||
|
||||
return dxil_module_get_struct_type(mod, "dx.types.fouri32", fields, 4);
|
||||
}
|
||||
|
||||
const struct dxil_type *
|
||||
dxil_module_add_function_type(struct dxil_module *m,
|
||||
const struct dxil_type *ret_type,
|
||||
|
@@ -323,6 +323,9 @@ dxil_module_get_res_bind_type(struct dxil_module *m);
|
||||
const struct dxil_type *
|
||||
dxil_module_get_res_props_type(struct dxil_module *m);
|
||||
|
||||
const struct dxil_type *
|
||||
dxil_module_get_fouri32_type(struct dxil_module *m);
|
||||
|
||||
const struct dxil_type *
|
||||
dxil_module_get_struct_type(struct dxil_module *m,
|
||||
const char *name,
|
||||
|
@@ -333,7 +333,14 @@ enum dxil_intr {
|
||||
DXIL_INTR_WAVE_IS_FIRST_LANE = 110,
|
||||
DXIL_INTR_WAVE_GET_LANE_INDEX = 111,
|
||||
DXIL_INTR_WAVE_GET_LANE_COUNT = 112,
|
||||
DXIL_INTR_WAVE_ANY_TRUE = 113,
|
||||
DXIL_INTR_WAVE_ALL_TRUE = 114,
|
||||
DXIL_INTR_WAVE_ACTIVE_ALL_EQUAL = 115,
|
||||
DXIL_INTR_WAVE_ACTIVE_BALLOT = 116,
|
||||
DXIL_INTR_WAVE_READ_LANE_AT = 117,
|
||||
DXIL_INTR_WAVE_READ_LANE_FIRST = 118,
|
||||
DXIL_INTR_QUAD_READ_LANE_AT = 122,
|
||||
DXIL_INTR_QUAD_OP = 123,
|
||||
|
||||
DXIL_INTR_LEGACY_F32TOF16 = 130,
|
||||
DXIL_INTR_LEGACY_F16TOF32 = 131,
|
||||
@@ -4427,6 +4434,112 @@ emit_read_first_invocation(struct ntd_context *ctx, nir_intrinsic_instr *intr)
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
emit_read_invocation(struct ntd_context *ctx, nir_intrinsic_instr *intr)
|
||||
{
|
||||
ctx->mod.feats.wave_ops = 1;
|
||||
bool quad = intr->intrinsic == nir_intrinsic_quad_broadcast;
|
||||
const struct dxil_func *func = dxil_get_function(&ctx->mod, quad ? "dx.op.quadReadLaneAt" : "dx.op.waveReadLaneAt",
|
||||
get_overload(nir_type_uint, intr->dest.ssa.bit_size));
|
||||
const struct dxil_value *args[] = {
|
||||
dxil_module_get_int32_const(&ctx->mod, quad ? DXIL_INTR_QUAD_READ_LANE_AT : DXIL_INTR_WAVE_READ_LANE_AT),
|
||||
get_src(ctx, &intr->src[0], 0, nir_type_uint),
|
||||
get_src(ctx, &intr->src[1], 0, nir_type_uint),
|
||||
};
|
||||
if (!func || !args[0] || !args[1] || !args[2])
|
||||
return false;
|
||||
|
||||
const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
|
||||
if (!ret)
|
||||
return false;
|
||||
store_dest_value(ctx, &intr->dest, 0, ret);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
emit_vote_eq(struct ntd_context *ctx, nir_intrinsic_instr *intr)
|
||||
{
|
||||
ctx->mod.feats.wave_ops = 1;
|
||||
nir_alu_type alu_type = intr->intrinsic == nir_intrinsic_vote_ieq ? nir_type_int : nir_type_float;
|
||||
const struct dxil_func *func = dxil_get_function(&ctx->mod, "dx.op.waveActiveAllEqual",
|
||||
get_overload(alu_type, intr->src[0].ssa->bit_size));
|
||||
const struct dxil_value *args[] = {
|
||||
dxil_module_get_int32_const(&ctx->mod, DXIL_INTR_WAVE_ACTIVE_ALL_EQUAL),
|
||||
get_src(ctx, intr->src, 0, alu_type),
|
||||
};
|
||||
if (!func || !args[0] || !args[1])
|
||||
return false;
|
||||
|
||||
const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
|
||||
if (!ret)
|
||||
return false;
|
||||
store_dest_value(ctx, &intr->dest, 0, ret);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
emit_vote(struct ntd_context *ctx, nir_intrinsic_instr *intr)
|
||||
{
|
||||
ctx->mod.feats.wave_ops = 1;
|
||||
bool any = intr->intrinsic == nir_intrinsic_vote_any;
|
||||
const struct dxil_func *func = dxil_get_function(&ctx->mod,
|
||||
any ? "dx.op.waveAnyTrue" : "dx.op.waveAllTrue",
|
||||
DXIL_NONE);
|
||||
const struct dxil_value *args[] = {
|
||||
dxil_module_get_int32_const(&ctx->mod, any ? DXIL_INTR_WAVE_ANY_TRUE : DXIL_INTR_WAVE_ALL_TRUE),
|
||||
get_src(ctx, intr->src, 0, nir_type_bool),
|
||||
};
|
||||
if (!func || !args[0] || !args[1])
|
||||
return false;
|
||||
|
||||
const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
|
||||
if (!ret)
|
||||
return false;
|
||||
store_dest_value(ctx, &intr->dest, 0, ret);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
emit_ballot(struct ntd_context *ctx, nir_intrinsic_instr *intr)
|
||||
{
|
||||
ctx->mod.feats.wave_ops = 1;
|
||||
const struct dxil_func *func = dxil_get_function(&ctx->mod, "dx.op.waveActiveBallot", DXIL_NONE);
|
||||
const struct dxil_value *args[] = {
|
||||
dxil_module_get_int32_const(&ctx->mod, DXIL_INTR_WAVE_ACTIVE_BALLOT),
|
||||
get_src(ctx, intr->src, 0, nir_type_bool),
|
||||
};
|
||||
if (!func || !args[0] || !args[1])
|
||||
return false;
|
||||
|
||||
const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
|
||||
if (!ret)
|
||||
return false;
|
||||
for (uint32_t i = 0; i < 4; ++i)
|
||||
store_dest_value(ctx, &intr->dest, i, dxil_emit_extractval(&ctx->mod, ret, i));
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
emit_quad_op(struct ntd_context *ctx, nir_intrinsic_instr *intr, enum dxil_quad_op_kind op)
|
||||
{
|
||||
ctx->mod.feats.wave_ops = 1;
|
||||
const struct dxil_func *func = dxil_get_function(&ctx->mod, "dx.op.quadOp",
|
||||
get_overload(nir_type_uint, intr->dest.ssa.bit_size));
|
||||
const struct dxil_value *args[] = {
|
||||
dxil_module_get_int32_const(&ctx->mod, DXIL_INTR_QUAD_OP),
|
||||
get_src(ctx, intr->src, 0, nir_type_uint),
|
||||
dxil_module_get_int8_const(&ctx->mod, op),
|
||||
};
|
||||
if (!func || !args[0] || !args[1] || !args[2])
|
||||
return false;
|
||||
|
||||
const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
|
||||
if (!ret)
|
||||
return false;
|
||||
store_dest_value(ctx, &intr->dest, 0, ret);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
|
||||
{
|
||||
@@ -4634,8 +4747,29 @@ emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
|
||||
return emit_load_unary_external_function(
|
||||
ctx, intr, "dx.op.waveGetLaneIndex", DXIL_INTR_WAVE_GET_LANE_INDEX, DXIL_NONE);
|
||||
|
||||
case nir_intrinsic_vote_feq:
|
||||
case nir_intrinsic_vote_ieq:
|
||||
return emit_vote_eq(ctx, intr);
|
||||
case nir_intrinsic_vote_any:
|
||||
case nir_intrinsic_vote_all:
|
||||
return emit_vote(ctx, intr);
|
||||
|
||||
case nir_intrinsic_ballot:
|
||||
return emit_ballot(ctx, intr);
|
||||
|
||||
case nir_intrinsic_read_first_invocation:
|
||||
return emit_read_first_invocation(ctx, intr);
|
||||
case nir_intrinsic_read_invocation:
|
||||
case nir_intrinsic_shuffle:
|
||||
case nir_intrinsic_quad_broadcast:
|
||||
return emit_read_invocation(ctx, intr);
|
||||
|
||||
case nir_intrinsic_quad_swap_horizontal:
|
||||
return emit_quad_op(ctx, intr, QUAD_READ_ACROSS_X);
|
||||
case nir_intrinsic_quad_swap_vertical:
|
||||
return emit_quad_op(ctx, intr, QUAD_READ_ACROSS_Y);
|
||||
case nir_intrinsic_quad_swap_diagonal:
|
||||
return emit_quad_op(ctx, intr, QUAD_READ_ACROSS_DIAGONAL);
|
||||
|
||||
case nir_intrinsic_load_num_workgroups:
|
||||
case nir_intrinsic_load_workgroup_size:
|
||||
|
Reference in New Issue
Block a user