diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c index 8e655a2f0ea..72d793480e7 100644 --- a/src/amd/vulkan/radv_shader.c +++ b/src/amd/vulkan/radv_shader.c @@ -638,6 +638,7 @@ radv_shader_compile_to_nir(struct radv_device *device, struct vk_shader_module * nir_lower_subgroups(nir, &(struct nir_lower_subgroups_options){ .subgroup_size = subgroup_size, .ballot_bit_size = ballot_bit_size, + .ballot_components = 1, .lower_to_scalar = 1, .lower_subgroup_masks = 1, .lower_shuffle = 1, diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 1471623cf53..bf2c9ff138b 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -4726,6 +4726,7 @@ bool nir_lower_is_helper_invocation(nir_shader *shader); typedef struct nir_lower_subgroups_options { uint8_t subgroup_size; uint8_t ballot_bit_size; + uint8_t ballot_components; bool lower_to_scalar:1; bool lower_vote_trivial:1; bool lower_vote_eq:1; diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h index 64832594816..2551da45f9f 100644 --- a/src/compiler/nir/nir_builder.h +++ b/src/compiler/nir/nir_builder.h @@ -1140,6 +1140,30 @@ nir_pad_vector(nir_builder *b, nir_ssa_def *src, unsigned num_components) return nir_vec(b, components, num_components); } +/** + * Pad a value to N components with copies of the given immediate of matching + * bit size. If the value already contains >= num_components, it is returned + * without change. + */ +static inline nir_ssa_def * +nir_pad_vector_imm_int(nir_builder *b, nir_ssa_def *src, uint64_t imm_val, + unsigned num_components) +{ + assert(src->num_components <= num_components); + if (src->num_components == num_components) + return src; + + nir_ssa_def *components[NIR_MAX_VEC_COMPONENTS]; + nir_ssa_def *imm = nir_imm_intN_t(b, imm_val, src->bit_size); + unsigned i = 0; + for (; i < src->num_components; i++) + components[i] = nir_channel(b, src, i); + for (; i < num_components; i++) + components[i] = imm; + + return nir_vec(b, components, num_components); +} + /** * Pad a value to 4 components with undefs of matching bit size. * If the value already contains >= 4 components, it is returned without change. diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c index 3f0126dbada..cfe74758ff1 100644 --- a/src/compiler/nir/nir_lower_subgroups.c +++ b/src/compiler/nir/nir_lower_subgroups.c @@ -23,6 +23,7 @@ #include "nir.h" #include "nir_builder.h" +#include "util/u_math.h" /** * \file nir_opt_intrinsics.c @@ -61,54 +62,24 @@ lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin) } static nir_ssa_def * -ballot_type_to_uint(nir_builder *b, nir_ssa_def *value, unsigned bit_size) +ballot_type_to_uint(nir_builder *b, nir_ssa_def *value, + const nir_lower_subgroups_options *options) { - /* We only use this on uvec4 types */ + /* Only the new-style SPIR-V subgroup instructions take a ballot result as + * an argument, so we only use this on uvec4 types. 
+ */ assert(value->num_components == 4 && value->bit_size == 32); - if (bit_size == 32) { - return nir_channel(b, value, 0); - } else { - assert(bit_size == 64); - return nir_pack_64_2x32_split(b, nir_channel(b, value, 0), - nir_channel(b, value, 1)); - } + return nir_extract_bits(b, &value, 1, 0, options->ballot_components, + options->ballot_bit_size); } -/* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */ static nir_ssa_def * uint_to_ballot_type(nir_builder *b, nir_ssa_def *value, unsigned num_components, unsigned bit_size) { - assert(value->num_components == 1); - assert(value->bit_size == 32 || value->bit_size == 64); - - nir_ssa_def *zero = nir_imm_int(b, 0); - if (num_components > 1) { - /* SPIR-V uses a uvec4 for ballot values */ - assert(num_components == 4); - assert(bit_size == 32); - - if (value->bit_size == 32) { - return nir_vec4(b, value, zero, zero, zero); - } else { - assert(value->bit_size == 64); - return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value), - nir_unpack_64_2x32_split_y(b, value), - zero, zero); - } - } else { - /* GLSL uses a uint64_t for ballot values */ - assert(num_components == 1); - assert(bit_size == 64); - - if (value->bit_size == 32) { - return nir_pack_64_2x32_split(b, value, zero); - } else { - assert(value->bit_size == 64); - return value; - } - } + value = nir_bitcast_vector(b, value, bit_size); + return nir_pad_vector_imm_int(b, value, 0, num_components); } static nir_ssa_def * @@ -317,13 +288,169 @@ lower_subgroups_filter(const nir_instr *instr, const void *_options) return instr->type == nir_instr_type_intrinsic; } +/* Return a ballot-mask-sized value which represents "val" sign-extended and + * then shifted left by "shift". Only particular values for "val" are + * supported, see below. + */ static nir_ssa_def * -build_subgroup_mask(nir_builder *b, unsigned bit_size, +build_ballot_imm_ishl(nir_builder *b, int64_t val, nir_ssa_def *shift, + const nir_lower_subgroups_options *options) +{ + /* This only works if all the high bits are the same as bit 1. */ + assert(((val << 62) >> 62) == val); + + /* First compute the result assuming one ballot component. */ + nir_ssa_def *result = + nir_ishl(b, nir_imm_intN_t(b, val, options->ballot_bit_size), shift); + + if (options->ballot_components == 1) + return result; + + /* Fix up the result when there is > 1 component. The idea is that nir_ishl + * masks out the high bits of the shift value already, so in case there's + * more than one component the component which 1 would be shifted into + * already has the right value and all we have to do is fixup the other + * components. Components below it should always be 0, and components above + * it must be either 0 or ~0 because of the assert above. For example, if + * the target ballot size is 2 x uint32, and we're shifting 1 by 33, then + * we'll feed 33 into ishl, which will mask it off to get 1, so we'll + * compute a single-component result of 2, which is correct for the second + * component, but the first component needs to be 0, which we get by + * comparing the high bits of the shift with 0 and selecting the original + * answer or 0 for the first component (and something similar with the + * second component). 
This idea is generalized here for any component count. + */ + nir_const_value min_shift[4] = { 0 }; + for (unsigned i = 0; i < options->ballot_components; i++) + min_shift[i].i32 = i * options->ballot_bit_size; + nir_ssa_def *min_shift_val = nir_build_imm(b, options->ballot_components, 32, min_shift); + + nir_const_value max_shift[4] = { 0 }; + for (unsigned i = 0; i < options->ballot_components; i++) + max_shift[i].i32 = (i + 1) * options->ballot_bit_size; + nir_ssa_def *max_shift_val = nir_build_imm(b, options->ballot_components, 32, max_shift); + + return nir_bcsel(b, nir_ult(b, shift, max_shift_val), + nir_bcsel(b, nir_ult(b, shift, min_shift_val), + nir_imm_intN_t(b, val >> 63, result->bit_size), + result), + nir_imm_intN_t(b, 0, result->bit_size)); +} + +static nir_ssa_def * +build_subgroup_eq_mask(nir_builder *b, + const nir_lower_subgroups_options *options) +{ + nir_ssa_def *subgroup_idx = nir_load_subgroup_invocation(b); + + return build_ballot_imm_ishl(b, 1, subgroup_idx, options); +} + +static nir_ssa_def * +build_subgroup_ge_mask(nir_builder *b, + const nir_lower_subgroups_options *options) +{ + nir_ssa_def *subgroup_idx = nir_load_subgroup_invocation(b); + + return build_ballot_imm_ishl(b, ~0ull, subgroup_idx, options); +} + +static nir_ssa_def * +build_subgroup_gt_mask(nir_builder *b, + const nir_lower_subgroups_options *options) +{ + nir_ssa_def *subgroup_idx = nir_load_subgroup_invocation(b); + + return build_ballot_imm_ishl(b, ~1ull, subgroup_idx, options); +} + +/* Return a mask which is 1 for threads up to the run-time subgroup size, i.e. + * 1 for the entire subgroup. SPIR-V requires us to return 0 for indices at or + * above the subgroup size for the masks, but gt_mask and ge_mask make them 1 + * so we have to "and" with this mask. + */ +static nir_ssa_def * +build_subgroup_mask(nir_builder *b, const nir_lower_subgroups_options *options) { - return nir_ushr(b, nir_imm_intN_t(b, ~0ull, bit_size), - nir_isub(b, nir_imm_int(b, bit_size), - nir_load_subgroup_size(b))); + nir_ssa_def *subgroup_size = nir_load_subgroup_size(b); + + /* First compute the result assuming one ballot component. */ + nir_ssa_def *result = + nir_ushr(b, nir_imm_intN_t(b, ~0ull, options->ballot_bit_size), + nir_isub_imm(b, options->ballot_bit_size, + subgroup_size)); + + /* Since the subgroup size and ballot bitsize are both powers of two, there + * are two possible cases to consider: + * + * (1) The subgroup size is less than the ballot bitsize. We need to return + * "result" in the first component and 0 in every other component. + * (2) The subgroup size is a multiple of the ballot bitsize. We need to + * return ~0 in each component whose index in the vector is less than the + * subgroup size divided by the ballot bitsize, and 0 otherwise. For example, + * with a target ballot type of 4 x uint32 and subgroup_size = 64 we'd need + * to return { ~0, ~0, 0, 0 }. + * + * In case (2) it turns out that "result" will be ~0, because + * "ballot_bit_size - subgroup_size" is also a multiple of + * "ballot_bit_size" and since nir_ushr masks the shift value it will + * be shifted by 0. This means that the first component can just be "result" + * in all cases. The other components will also get the correct value in + * case (1) if we just use the rule in case (2), so we'll get the correct + * result if we just follow (2) and then replace the first component with + * "result". 
+ */ + nir_const_value min_idx[4] = { 0 }; + for (unsigned i = 0; i < options->ballot_components; i++) + min_idx[i].i32 = i * options->ballot_bit_size; + nir_ssa_def *min_idx_val = nir_build_imm(b, options->ballot_components, 32, min_idx); + + nir_ssa_def *result_extended = + nir_pad_vector_imm_int(b, result, ~0ull, options->ballot_components); + + return nir_bcsel(b, nir_ult(b, min_idx_val, subgroup_size), + result_extended, nir_imm_intN_t(b, 0, options->ballot_bit_size)); +} + +static nir_ssa_def * +vec_bit_count(nir_builder *b, nir_ssa_def *value) +{ + nir_ssa_def *vec_result = nir_bit_count(b, value); + nir_ssa_def *result = nir_channel(b, vec_result, 0); + for (unsigned i = 1; i < value->num_components; i++) + result = nir_iadd(b, result, nir_channel(b, vec_result, i)); + return result; +} + +static nir_ssa_def * +vec_find_lsb(nir_builder *b, nir_ssa_def *value) +{ + nir_ssa_def *vec_result = nir_find_lsb(b, value); + nir_ssa_def *result = nir_imm_int(b, -1); + for (int i = value->num_components - 1; i >= 0; i--) { + nir_ssa_def *channel = nir_channel(b, vec_result, i); + /* result = channel >= 0 ? (i * bitsize + channel) : result */ + result = nir_bcsel(b, nir_ige(b, channel, nir_imm_int(b, 0)), + nir_iadd_imm(b, channel, i * value->bit_size), + result); + } + return result; +} + +static nir_ssa_def * +vec_find_msb(nir_builder *b, nir_ssa_def *value) +{ + nir_ssa_def *vec_result = nir_ufind_msb(b, value); + nir_ssa_def *result = nir_imm_int(b, -1); + for (unsigned i = 0; i < value->num_components; i++) { + nir_ssa_def *channel = nir_channel(b, vec_result, i); + /* result = channel >= 0 ? (i * bitsize + channel) : result */ + result = nir_bcsel(b, nir_ige(b, channel, nir_imm_int(b, 0)), + nir_iadd_imm(b, channel, i * value->bit_size), + result); + } + return result; } static nir_ssa_def * @@ -410,32 +537,24 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) if (!options->lower_subgroup_masks) return NULL; - /* If either the result or the requested bit size is 64-bits then we - * know that we have 64-bit types and using them will probably be more - * efficient than messing around with 32-bit shifts and packing. 
- */ - const unsigned bit_size = MAX2(options->ballot_bit_size, - intrin->dest.ssa.bit_size); - - nir_ssa_def *count = nir_load_subgroup_invocation(b); nir_ssa_def *val; switch (intrin->intrinsic) { case nir_intrinsic_load_subgroup_eq_mask: - val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count); + val = build_subgroup_eq_mask(b, options); break; case nir_intrinsic_load_subgroup_ge_mask: - val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count), - build_subgroup_mask(b, bit_size, options)); + val = nir_iand(b, build_subgroup_ge_mask(b, options), + build_subgroup_mask(b, options)); break; case nir_intrinsic_load_subgroup_gt_mask: - val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count), - build_subgroup_mask(b, bit_size, options)); + val = nir_iand(b, build_subgroup_gt_mask(b, options), + build_subgroup_mask(b, options)); break; case nir_intrinsic_load_subgroup_le_mask: - val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count)); + val = nir_inot(b, build_subgroup_gt_mask(b, options)); break; case nir_intrinsic_load_subgroup_lt_mask: - val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count)); + val = nir_inot(b, build_subgroup_ge_mask(b, options)); break; default: unreachable("you seriously can't tell this is unreachable?"); @@ -447,11 +566,13 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) } case nir_intrinsic_ballot: { - if (intrin->dest.ssa.num_components == 1 && + if (intrin->dest.ssa.num_components == options->ballot_components && intrin->dest.ssa.bit_size == options->ballot_bit_size) return NULL; - nir_ssa_def *ballot = nir_ballot(b, 1, options->ballot_bit_size, intrin->src[0].ssa); + nir_ssa_def *ballot = + nir_ballot(b, options->ballot_components, options->ballot_bit_size, + intrin->src[0].ssa); return uint_to_ballot_type(b, ballot, intrin->dest.ssa.num_components, @@ -464,7 +585,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) case nir_intrinsic_ballot_find_msb: { assert(intrin->src[0].is_ssa); nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa, - options->ballot_bit_size); + options); if (intrin->intrinsic != nir_intrinsic_ballot_bitfield_extract && intrin->intrinsic != nir_intrinsic_ballot_find_lsb) { @@ -487,22 +608,31 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) * valid bits, we hit the undefined results case and we can return * anything we want. */ - int_val = nir_iand(b, int_val, - build_subgroup_mask(b, options->ballot_bit_size, options)); + int_val = nir_iand(b, int_val, build_subgroup_mask(b, options)); } switch (intrin->intrinsic) { - case nir_intrinsic_ballot_bitfield_extract: + case nir_intrinsic_ballot_bitfield_extract: { assert(intrin->src[1].is_ssa); - return nir_i2b(b, nir_iand(b, nir_ushr(b, int_val, - intrin->src[1].ssa), - nir_imm_intN_t(b, 1, options->ballot_bit_size))); + nir_ssa_def *idx = intrin->src[1].ssa; + if (int_val->num_components > 1) { + /* idx will be truncated by nir_ushr, so we just need to select + * the right component using the bits of idx that are truncated in + * the shift. 
+ */ + int_val = + nir_vector_extract(b, int_val, + nir_udiv_imm(b, idx, int_val->bit_size)); + } + + return nir_i2b(b, nir_iand_imm(b, nir_ushr(b, int_val, idx), 1)); + } case nir_intrinsic_ballot_bit_count_reduce: - return nir_bit_count(b, int_val); + return vec_bit_count(b, int_val); case nir_intrinsic_ballot_find_lsb: - return nir_find_lsb(b, int_val); + return vec_find_lsb(b, int_val); case nir_intrinsic_ballot_find_msb: - return nir_ufind_msb(b, int_val); + return vec_find_msb(b, int_val); default: unreachable("you seriously can't tell this is unreachable?"); } @@ -510,20 +640,18 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) case nir_intrinsic_ballot_bit_count_exclusive: case nir_intrinsic_ballot_bit_count_inclusive: { - nir_ssa_def *count = nir_load_subgroup_invocation(b); - nir_ssa_def *mask = nir_imm_intN_t(b, ~0ull, options->ballot_bit_size); + nir_ssa_def *mask; if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) { - const unsigned bits = options->ballot_bit_size; - mask = nir_ushr(b, mask, nir_isub(b, nir_imm_int(b, bits - 1), count)); + mask = nir_inot(b, build_subgroup_gt_mask(b, options)); } else { - mask = nir_inot(b, nir_ishl(b, mask, count)); + mask = nir_inot(b, build_subgroup_ge_mask(b, options)); } assert(intrin->src[0].is_ssa); nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa, - options->ballot_bit_size); + options); - return nir_bit_count(b, nir_iand(b, int_val, mask)); + return vec_bit_count(b, nir_iand(b, int_val, mask)); } case nir_intrinsic_elect: { diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir.c b/src/gallium/auxiliary/gallivm/lp_bld_nir.c index d932c9f11f7..420cac99e0f 100644 --- a/src/gallium/auxiliary/gallivm/lp_bld_nir.c +++ b/src/gallium/auxiliary/gallivm/lp_bld_nir.c @@ -2354,6 +2354,7 @@ void lp_build_opt_nir(struct nir_shader *nir) const nir_lower_subgroups_options subgroups_options = { .subgroup_size = lp_native_vector_width / 32, .ballot_bit_size = 32, + .ballot_components = 1, .lower_to_scalar = true, .lower_subgroup_masks = true, }; diff --git a/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp b/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp index e12fee25529..59c0ed89eb6 100644 --- a/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp +++ b/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp @@ -3115,6 +3115,7 @@ Converter::run() struct nir_lower_subgroups_options subgroup_options = {}; subgroup_options.subgroup_size = 32; subgroup_options.ballot_bit_size = 32; + subgroup_options.ballot_components = 1; subgroup_options.lower_elect = true; /* prepare for IO lowering */ diff --git a/src/gallium/drivers/radeonsi/si_shader_nir.c b/src/gallium/drivers/radeonsi/si_shader_nir.c index ae1192d97ed..5c972a368a3 100644 --- a/src/gallium/drivers/radeonsi/si_shader_nir.c +++ b/src/gallium/drivers/radeonsi/si_shader_nir.c @@ -821,6 +821,7 @@ static void si_lower_nir(struct si_screen *sscreen, struct nir_shader *nir) const nir_lower_subgroups_options subgroups_options = { .subgroup_size = 64, .ballot_bit_size = 64, + .ballot_components = 1, .lower_to_scalar = true, .lower_subgroup_masks = true, .lower_vote_trivial = false, diff --git a/src/intel/compiler/brw_nir.c b/src/intel/compiler/brw_nir.c index ea9ce6df643..d27af2c0ec2 100644 --- a/src/intel/compiler/brw_nir.c +++ b/src/intel/compiler/brw_nir.c @@ -874,6 +874,7 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir, const nir_lower_subgroups_options subgroups_options = { .ballot_bit_size = 32, 
+ .ballot_components = 1, .lower_to_scalar = true, .lower_vote_trivial = !is_scalar, .lower_shuffle = true, @@ -1354,6 +1355,7 @@ brw_nir_apply_key(nir_shader *nir, .subgroup_size = get_subgroup_size(nir->info.stage, key, max_subgroup_size), .ballot_bit_size = 32, + .ballot_components = 1, .lower_subgroup_masks = true, }; OPT(nir_lower_subgroups, &subgroups_options);
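
Editor's illustration (not part of the patch): the nested nir_bcsel in build_ballot_imm_ishl() is easier to follow when replayed on the CPU. The standalone C sketch below assumes a 2 x uint32 ballot layout; the helper name ballot_imm_ishl_2x32 and the hard-coded component count are invented for the example, and it only mirrors the shift-then-select idea described in the comment above build_ballot_imm_ishl().

#include <stdint.h>
#include <stdio.h>

static void
ballot_imm_ishl_2x32(int64_t val, unsigned shift, uint32_t out[2])
{
   /* Same restriction as the NIR helper: every bit of "val" above bit 1
    * must equal bit 1 (so 1, ~0 and ~1 are all fine). */

   /* Like nir_ishl, the hardware shift only looks at the low 5 bits, so
    * compute the single-component result with the truncated shift. */
   uint32_t result = (uint32_t)val << (shift & 31);

   for (unsigned i = 0; i < 2; i++) {
      unsigned min_shift = i * 32;        /* index of the first bit in component i */
      unsigned max_shift = (i + 1) * 32;  /* one past the last bit in component i */

      if (shift >= max_shift)
         out[i] = 0;                      /* bit "shift" lies above this component */
      else if (shift < min_shift)
         out[i] = (uint32_t)(val >> 63);  /* below it: sign-extension of "val" */
      else
         out[i] = result;                 /* the component the shift lands in */
   }
}

int
main(void)
{
   uint32_t eq[2], ge[2];
   ballot_imm_ishl_2x32(1, 33, eq);           /* eq_mask for invocation 33 */
   ballot_imm_ishl_2x32(~(int64_t)0, 33, ge); /* ge_mask for invocation 33 */
   printf("eq = { 0x%08x, 0x%08x }\n", eq[0], eq[1]); /* { 0x00000000, 0x00000002 } */
   printf("ge = { 0x%08x, 0x%08x }\n", ge[0], ge[1]); /* { 0x00000000, 0xfffffffe } */
   return 0;
}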
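
A similar host-side sketch of what build_subgroup_mask() computes, assuming a 4 x uint32 ballot. Again the helper name subgroup_mask_4x32 is made up for the illustration; only the pad-with-~0-then-select logic from the patch is reproduced.

#include <stdint.h>
#include <stdio.h>

static void
subgroup_mask_4x32(unsigned subgroup_size, uint32_t out[4])
{
   /* Single-component result; like nir_ushr, mask the shift count, so for
    * subgroup sizes that are a multiple of 32 this shifts by 0 and gives ~0. */
   uint32_t result = ~0u >> ((32u - subgroup_size) & 31);

   for (unsigned i = 0; i < 4; i++) {
      unsigned min_idx = i * 32; /* index of the first ballot bit in component i */
      /* "result" in component 0, ~0 in the higher components ... */
      uint32_t extended = (i == 0) ? result : ~0u;
      /* ... then zero out components that start at or beyond the subgroup size. */
      out[i] = (min_idx < subgroup_size) ? extended : 0;
   }
}

int
main(void)
{
   uint32_t mask[4];
   subgroup_mask_4x32(64, mask); /* prints { 0xffffffff, 0xffffffff, 0x0, 0x0 } */
   printf("{ 0x%x, 0x%x, 0x%x, 0x%x }\n", mask[0], mask[1], mask[2], mask[3]);
   subgroup_mask_4x32(16, mask); /* prints { 0xffff, 0x0, 0x0, 0x0 } */
   printf("{ 0x%x, 0x%x, 0x%x, 0x%x }\n", mask[0], mask[1], mask[2], mask[3]);
   return 0;
}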