From 2c5d96bb58e302913ef7682d292fe33a10c61ebe Mon Sep 17 00:00:00 2001 From: Jesse Natalie Date: Thu, 19 Jan 2023 13:25:53 -0800 Subject: [PATCH] microsoft/compiler: Implement more wave/quad ops This handles ballot, vote, shuffle, broadcast, and quads Part-of: --- src/microsoft/compiler/dxil_enums.h | 6 ++ src/microsoft/compiler/dxil_function.c | 8 ++ src/microsoft/compiler/dxil_function.h | 1 + src/microsoft/compiler/dxil_module.c | 10 ++ src/microsoft/compiler/dxil_module.h | 3 + src/microsoft/compiler/nir_to_dxil.c | 134 +++++++++++++++++++++++++ 6 files changed, 162 insertions(+) diff --git a/src/microsoft/compiler/dxil_enums.h b/src/microsoft/compiler/dxil_enums.h index 7241de189d9..d8181dea56a 100644 --- a/src/microsoft/compiler/dxil_enums.h +++ b/src/microsoft/compiler/dxil_enums.h @@ -347,6 +347,12 @@ enum dxil_signature_element_extended_properties { DXIL_SIGNATURE_ELEMENT_USAGE_COMPONENT_MASK = 3, }; +enum dxil_quad_op_kind { + QUAD_READ_ACROSS_X = 0, + QUAD_READ_ACROSS_Y = 1, + QUAD_READ_ACROSS_DIAGONAL = 2, +}; + #ifdef __cplusplus extern "C" { #endif diff --git a/src/microsoft/compiler/dxil_function.c b/src/microsoft/compiler/dxil_function.c index 4fdc8e3f162..b79b96973be 100644 --- a/src/microsoft/compiler/dxil_function.c +++ b/src/microsoft/compiler/dxil_function.c @@ -97,6 +97,13 @@ static struct predefined_func_descr predefined_funcs[] = { {"dx.op.waveGetLaneIndex", "i", "i", DXIL_ATTR_KIND_READ_NONE}, {"dx.op.waveGetLaneCount", "i", "i", DXIL_ATTR_KIND_READ_NONE}, {"dx.op.waveReadLaneFirst", "O", "iO", DXIL_ATTR_KIND_NO_UNWIND}, +{"dx.op.waveReadLaneAt", "O", "iOi", DXIL_ATTR_KIND_NO_UNWIND}, +{"dx.op.waveAnyTrue", "b", "ib", DXIL_ATTR_KIND_NO_UNWIND}, +{"dx.op.waveAllTrue", "b", "ib", DXIL_ATTR_KIND_NO_UNWIND}, +{"dx.op.waveActiveAllEqual", "b", "iO", DXIL_ATTR_KIND_NO_UNWIND}, +{"dx.op.waveActiveBallot", "F", "ib", DXIL_ATTR_KIND_NO_UNWIND}, +{"dx.op.quadReadLaneAt", "O", "iOi", DXIL_ATTR_KIND_NO_UNWIND}, +{"dx.op.quadOp", "O", "iOc", DXIL_ATTR_KIND_NO_UNWIND}, }; struct func_descr { @@ -207,6 +214,7 @@ get_type_from_string(struct dxil_module *mod, const char *param_descr, const struct dxil_type *target = get_type_from_string(mod, param_descr, overload, idx); return dxil_module_get_pointer_type(mod, target); } + case DXIL_FUNC_PARAM_FOURI32: return dxil_module_get_fouri32_type(mod); default: assert(0 && "unknown type identifier"); } diff --git a/src/microsoft/compiler/dxil_function.h b/src/microsoft/compiler/dxil_function.h index c8950e18d88..67684cded02 100644 --- a/src/microsoft/compiler/dxil_function.h +++ b/src/microsoft/compiler/dxil_function.h @@ -44,6 +44,7 @@ #define DXIL_FUNC_PARAM_SAMPLE_POS 'S' #define DXIL_FUNC_PARAM_RES_BIND '#' #define DXIL_FUNC_PARAM_RES_PROPS 'P' +#define DXIL_FUNC_PARAM_FOURI32 'F' #include "dxil_module.h" #include "util/rb_tree.h" diff --git a/src/microsoft/compiler/dxil_module.c b/src/microsoft/compiler/dxil_module.c index 89cdfac1f30..5272459b145 100644 --- a/src/microsoft/compiler/dxil_module.c +++ b/src/microsoft/compiler/dxil_module.c @@ -909,6 +909,16 @@ dxil_module_get_res_props_type(struct dxil_module *mod) return dxil_module_get_struct_type(mod, "dx.types.ResourceProperties", fields, 2); } +const struct dxil_type * +dxil_module_get_fouri32_type(struct dxil_module *mod) +{ + /* %dx.types.fouri32 = type { i32, i32, i32, i32 } */ + const struct dxil_type *int32_type = dxil_module_get_int_type(mod, 32); + const struct dxil_type *fields[4] = { int32_type, int32_type, int32_type, int32_type }; + + return dxil_module_get_struct_type(mod, "dx.types.fouri32", fields, 4); +} + const struct dxil_type * dxil_module_add_function_type(struct dxil_module *m, const struct dxil_type *ret_type, diff --git a/src/microsoft/compiler/dxil_module.h b/src/microsoft/compiler/dxil_module.h index 8669a7cc6f6..9ab8d3aeaca 100644 --- a/src/microsoft/compiler/dxil_module.h +++ b/src/microsoft/compiler/dxil_module.h @@ -323,6 +323,9 @@ dxil_module_get_res_bind_type(struct dxil_module *m); const struct dxil_type * dxil_module_get_res_props_type(struct dxil_module *m); +const struct dxil_type * +dxil_module_get_fouri32_type(struct dxil_module *m); + const struct dxil_type * dxil_module_get_struct_type(struct dxil_module *m, const char *name, diff --git a/src/microsoft/compiler/nir_to_dxil.c b/src/microsoft/compiler/nir_to_dxil.c index 899462e726a..96ad74aed42 100644 --- a/src/microsoft/compiler/nir_to_dxil.c +++ b/src/microsoft/compiler/nir_to_dxil.c @@ -333,7 +333,14 @@ enum dxil_intr { DXIL_INTR_WAVE_IS_FIRST_LANE = 110, DXIL_INTR_WAVE_GET_LANE_INDEX = 111, DXIL_INTR_WAVE_GET_LANE_COUNT = 112, + DXIL_INTR_WAVE_ANY_TRUE = 113, + DXIL_INTR_WAVE_ALL_TRUE = 114, + DXIL_INTR_WAVE_ACTIVE_ALL_EQUAL = 115, + DXIL_INTR_WAVE_ACTIVE_BALLOT = 116, + DXIL_INTR_WAVE_READ_LANE_AT = 117, DXIL_INTR_WAVE_READ_LANE_FIRST = 118, + DXIL_INTR_QUAD_READ_LANE_AT = 122, + DXIL_INTR_QUAD_OP = 123, DXIL_INTR_LEGACY_F32TOF16 = 130, DXIL_INTR_LEGACY_F16TOF32 = 131, @@ -4427,6 +4434,112 @@ emit_read_first_invocation(struct ntd_context *ctx, nir_intrinsic_instr *intr) return true; } +static bool +emit_read_invocation(struct ntd_context *ctx, nir_intrinsic_instr *intr) +{ + ctx->mod.feats.wave_ops = 1; + bool quad = intr->intrinsic == nir_intrinsic_quad_broadcast; + const struct dxil_func *func = dxil_get_function(&ctx->mod, quad ? "dx.op.quadReadLaneAt" : "dx.op.waveReadLaneAt", + get_overload(nir_type_uint, intr->dest.ssa.bit_size)); + const struct dxil_value *args[] = { + dxil_module_get_int32_const(&ctx->mod, quad ? DXIL_INTR_QUAD_READ_LANE_AT : DXIL_INTR_WAVE_READ_LANE_AT), + get_src(ctx, &intr->src[0], 0, nir_type_uint), + get_src(ctx, &intr->src[1], 0, nir_type_uint), + }; + if (!func || !args[0] || !args[1] || !args[2]) + return false; + + const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args)); + if (!ret) + return false; + store_dest_value(ctx, &intr->dest, 0, ret); + return true; +} + +static bool +emit_vote_eq(struct ntd_context *ctx, nir_intrinsic_instr *intr) +{ + ctx->mod.feats.wave_ops = 1; + nir_alu_type alu_type = intr->intrinsic == nir_intrinsic_vote_ieq ? nir_type_int : nir_type_float; + const struct dxil_func *func = dxil_get_function(&ctx->mod, "dx.op.waveActiveAllEqual", + get_overload(alu_type, intr->src[0].ssa->bit_size)); + const struct dxil_value *args[] = { + dxil_module_get_int32_const(&ctx->mod, DXIL_INTR_WAVE_ACTIVE_ALL_EQUAL), + get_src(ctx, intr->src, 0, alu_type), + }; + if (!func || !args[0] || !args[1]) + return false; + + const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args)); + if (!ret) + return false; + store_dest_value(ctx, &intr->dest, 0, ret); + return true; +} + +static bool +emit_vote(struct ntd_context *ctx, nir_intrinsic_instr *intr) +{ + ctx->mod.feats.wave_ops = 1; + bool any = intr->intrinsic == nir_intrinsic_vote_any; + const struct dxil_func *func = dxil_get_function(&ctx->mod, + any ? "dx.op.waveAnyTrue" : "dx.op.waveAllTrue", + DXIL_NONE); + const struct dxil_value *args[] = { + dxil_module_get_int32_const(&ctx->mod, any ? DXIL_INTR_WAVE_ANY_TRUE : DXIL_INTR_WAVE_ALL_TRUE), + get_src(ctx, intr->src, 0, nir_type_bool), + }; + if (!func || !args[0] || !args[1]) + return false; + + const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args)); + if (!ret) + return false; + store_dest_value(ctx, &intr->dest, 0, ret); + return true; +} + +static bool +emit_ballot(struct ntd_context *ctx, nir_intrinsic_instr *intr) +{ + ctx->mod.feats.wave_ops = 1; + const struct dxil_func *func = dxil_get_function(&ctx->mod, "dx.op.waveActiveBallot", DXIL_NONE); + const struct dxil_value *args[] = { + dxil_module_get_int32_const(&ctx->mod, DXIL_INTR_WAVE_ACTIVE_BALLOT), + get_src(ctx, intr->src, 0, nir_type_bool), + }; + if (!func || !args[0] || !args[1]) + return false; + + const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args)); + if (!ret) + return false; + for (uint32_t i = 0; i < 4; ++i) + store_dest_value(ctx, &intr->dest, i, dxil_emit_extractval(&ctx->mod, ret, i)); + return true; +} + +static bool +emit_quad_op(struct ntd_context *ctx, nir_intrinsic_instr *intr, enum dxil_quad_op_kind op) +{ + ctx->mod.feats.wave_ops = 1; + const struct dxil_func *func = dxil_get_function(&ctx->mod, "dx.op.quadOp", + get_overload(nir_type_uint, intr->dest.ssa.bit_size)); + const struct dxil_value *args[] = { + dxil_module_get_int32_const(&ctx->mod, DXIL_INTR_QUAD_OP), + get_src(ctx, intr->src, 0, nir_type_uint), + dxil_module_get_int8_const(&ctx->mod, op), + }; + if (!func || !args[0] || !args[1] || !args[2]) + return false; + + const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args)); + if (!ret) + return false; + store_dest_value(ctx, &intr->dest, 0, ret); + return true; +} + static bool emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr) { @@ -4634,8 +4747,29 @@ emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr) return emit_load_unary_external_function( ctx, intr, "dx.op.waveGetLaneIndex", DXIL_INTR_WAVE_GET_LANE_INDEX, DXIL_NONE); + case nir_intrinsic_vote_feq: + case nir_intrinsic_vote_ieq: + return emit_vote_eq(ctx, intr); + case nir_intrinsic_vote_any: + case nir_intrinsic_vote_all: + return emit_vote(ctx, intr); + + case nir_intrinsic_ballot: + return emit_ballot(ctx, intr); + case nir_intrinsic_read_first_invocation: return emit_read_first_invocation(ctx, intr); + case nir_intrinsic_read_invocation: + case nir_intrinsic_shuffle: + case nir_intrinsic_quad_broadcast: + return emit_read_invocation(ctx, intr); + + case nir_intrinsic_quad_swap_horizontal: + return emit_quad_op(ctx, intr, QUAD_READ_ACROSS_X); + case nir_intrinsic_quad_swap_vertical: + return emit_quad_op(ctx, intr, QUAD_READ_ACROSS_Y); + case nir_intrinsic_quad_swap_diagonal: + return emit_quad_op(ctx, intr, QUAD_READ_ACROSS_DIAGONAL); case nir_intrinsic_load_num_workgroups: case nir_intrinsic_load_workgroup_size: