nir/subgroups: Support > 1 ballot components

Qualcomm has a mode with a subgroup size of 128, so just emitting larger
integer operations and then lowering them later isn't an option. This
makes the pass able to handle the lowering itself, so that we don't have
to go down to 64-thread wavefronts when ballots are used.

(The GLSL and legacy SPIR-V extensions only support a maximum of 64
threads, but I guess we'll cross that bridge when we come to it...)

Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6752>
This commit is contained in:
Connor Abbott
2020-09-10 18:48:04 +02:00
committed by Marge Bot
parent 90819b9b0e
commit e4e79de2a4
8 changed files with 236 additions and 77 deletions

View File

@@ -638,6 +638,7 @@ radv_shader_compile_to_nir(struct radv_device *device, struct vk_shader_module *
nir_lower_subgroups(nir, &(struct nir_lower_subgroups_options){
.subgroup_size = subgroup_size,
.ballot_bit_size = ballot_bit_size,
.ballot_components = 1,
.lower_to_scalar = 1,
.lower_subgroup_masks = 1,
.lower_shuffle = 1,

View File

@@ -4726,6 +4726,7 @@ bool nir_lower_is_helper_invocation(nir_shader *shader);
typedef struct nir_lower_subgroups_options {
uint8_t subgroup_size;
uint8_t ballot_bit_size;
uint8_t ballot_components;
bool lower_to_scalar:1;
bool lower_vote_trivial:1;
bool lower_vote_eq:1;

View File

@@ -1140,6 +1140,30 @@ nir_pad_vector(nir_builder *b, nir_ssa_def *src, unsigned num_components)
return nir_vec(b, components, num_components);
}
/**
* Pad a value to N components with copies of the given immediate of matching
* bit size. If the value already contains >= num_components, it is returned
* without change.
*/
static inline nir_ssa_def *
nir_pad_vector_imm_int(nir_builder *b, nir_ssa_def *src, uint64_t imm_val,
unsigned num_components)
{
assert(src->num_components <= num_components);
if (src->num_components == num_components)
return src;
nir_ssa_def *components[NIR_MAX_VEC_COMPONENTS];
nir_ssa_def *imm = nir_imm_intN_t(b, imm_val, src->bit_size);
unsigned i = 0;
for (; i < src->num_components; i++)
components[i] = nir_channel(b, src, i);
for (; i < num_components; i++)
components[i] = imm;
return nir_vec(b, components, num_components);
}
/**
* Pad a value to 4 components with undefs of matching bit size.
* If the value already contains >= 4 components, it is returned without change.

View File

@@ -23,6 +23,7 @@
#include "nir.h"
#include "nir_builder.h"
#include "util/u_math.h"
/**
* \file nir_opt_intrinsics.c
@@ -61,54 +62,24 @@ lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin)
}
static nir_ssa_def *
ballot_type_to_uint(nir_builder *b, nir_ssa_def *value, unsigned bit_size)
ballot_type_to_uint(nir_builder *b, nir_ssa_def *value,
const nir_lower_subgroups_options *options)
{
/* We only use this on uvec4 types */
/* Only the new-style SPIR-V subgroup instructions take a ballot result as
* an argument, so we only use this on uvec4 types.
*/
assert(value->num_components == 4 && value->bit_size == 32);
if (bit_size == 32) {
return nir_channel(b, value, 0);
} else {
assert(bit_size == 64);
return nir_pack_64_2x32_split(b, nir_channel(b, value, 0),
nir_channel(b, value, 1));
}
return nir_extract_bits(b, &value, 1, 0, options->ballot_components,
options->ballot_bit_size);
}
/* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */
static nir_ssa_def *
uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
unsigned num_components, unsigned bit_size)
{
assert(value->num_components == 1);
assert(value->bit_size == 32 || value->bit_size == 64);
nir_ssa_def *zero = nir_imm_int(b, 0);
if (num_components > 1) {
/* SPIR-V uses a uvec4 for ballot values */
assert(num_components == 4);
assert(bit_size == 32);
if (value->bit_size == 32) {
return nir_vec4(b, value, zero, zero, zero);
} else {
assert(value->bit_size == 64);
return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value),
nir_unpack_64_2x32_split_y(b, value),
zero, zero);
}
} else {
/* GLSL uses a uint64_t for ballot values */
assert(num_components == 1);
assert(bit_size == 64);
if (value->bit_size == 32) {
return nir_pack_64_2x32_split(b, value, zero);
} else {
assert(value->bit_size == 64);
return value;
}
}
value = nir_bitcast_vector(b, value, bit_size);
return nir_pad_vector_imm_int(b, value, 0, num_components);
}
static nir_ssa_def *
@@ -317,13 +288,169 @@ lower_subgroups_filter(const nir_instr *instr, const void *_options)
return instr->type == nir_instr_type_intrinsic;
}
/* Return a ballot-mask-sized value which represents "val" sign-extended and
* then shifted left by "shift". Only particular values for "val" are
* supported, see below.
*/
static nir_ssa_def *
build_subgroup_mask(nir_builder *b, unsigned bit_size,
build_ballot_imm_ishl(nir_builder *b, int64_t val, nir_ssa_def *shift,
const nir_lower_subgroups_options *options)
{
/* This only works if all the high bits are the same as bit 1. */
assert(((val << 62) >> 62) == val);
/* First compute the result assuming one ballot component. */
nir_ssa_def *result =
nir_ishl(b, nir_imm_intN_t(b, val, options->ballot_bit_size), shift);
if (options->ballot_components == 1)
return result;
/* Fix up the result when there is > 1 component. The idea is that nir_ishl
* masks out the high bits of the shift value already, so in case there's
* more than one component the component which 1 would be shifted into
* already has the right value and all we have to do is fixup the other
* components. Components below it should always be 0, and components above
* it must be either 0 or ~0 because of the assert above. For example, if
* the target ballot size is 2 x uint32, and we're shifting 1 by 33, then
* we'll feed 33 into ishl, which will mask it off to get 1, so we'll
* compute a single-component result of 2, which is correct for the second
* component, but the first component needs to be 0, which we get by
* comparing the high bits of the shift with 0 and selecting the original
* answer or 0 for the first component (and something similar with the
* second component). This idea is generalized here for any component count
*/
nir_const_value min_shift[4] = { 0 };
for (unsigned i = 0; i < options->ballot_components; i++)
min_shift[i].i32 = i * options->ballot_bit_size;
nir_ssa_def *min_shift_val = nir_build_imm(b, options->ballot_components, 32, min_shift);
nir_const_value max_shift[4] = { 0 };
for (unsigned i = 0; i < options->ballot_components; i++)
max_shift[i].i32 = (i + 1) * options->ballot_bit_size;
nir_ssa_def *max_shift_val = nir_build_imm(b, options->ballot_components, 32, max_shift);
return nir_bcsel(b, nir_ult(b, shift, max_shift_val),
nir_bcsel(b, nir_ult(b, shift, min_shift_val),
nir_imm_intN_t(b, val >> 63, result->bit_size),
result),
nir_imm_intN_t(b, 0, result->bit_size));
}
static nir_ssa_def *
build_subgroup_eq_mask(nir_builder *b,
const nir_lower_subgroups_options *options)
{
nir_ssa_def *subgroup_idx = nir_load_subgroup_invocation(b);
return build_ballot_imm_ishl(b, 1, subgroup_idx, options);
}
static nir_ssa_def *
build_subgroup_ge_mask(nir_builder *b,
const nir_lower_subgroups_options *options)
{
nir_ssa_def *subgroup_idx = nir_load_subgroup_invocation(b);
return build_ballot_imm_ishl(b, ~0ull, subgroup_idx, options);
}
static nir_ssa_def *
build_subgroup_gt_mask(nir_builder *b,
const nir_lower_subgroups_options *options)
{
nir_ssa_def *subgroup_idx = nir_load_subgroup_invocation(b);
return build_ballot_imm_ishl(b, ~1ull, subgroup_idx, options);
}
/* Return a mask which is 1 for threads up to the run-time subgroup size, i.e.
* 1 for the entire subgroup. SPIR-V requires us to return 0 for indices at or
* above the subgroup size for the masks, but gt_mask and ge_mask make them 1
* so we have to "and" with this mask.
*/
static nir_ssa_def *
build_subgroup_mask(nir_builder *b,
const nir_lower_subgroups_options *options)
{
return nir_ushr(b, nir_imm_intN_t(b, ~0ull, bit_size),
nir_isub(b, nir_imm_int(b, bit_size),
nir_load_subgroup_size(b)));
nir_ssa_def *subgroup_size = nir_load_subgroup_size(b);
/* First compute the result assuming one ballot component. */
nir_ssa_def *result =
nir_ushr(b, nir_imm_intN_t(b, ~0ull, options->ballot_bit_size),
nir_isub_imm(b, options->ballot_bit_size,
subgroup_size));
/* Since the subgroup size and ballot bitsize are both powers of two, there
* are two possible cases to consider:
*
* (1) The subgroup size is less than the ballot bitsize. We need to return
* "result" in the first component and 0 in every other component.
* (2) The subgroup size is a multiple of the ballot bitsize. We need to
* return ~0 if the subgroup size divided by the ballot bitsize is less
* than or equal to the index in the vector and 0 otherwise. For example,
* with a target ballot type of 4 x uint32 and subgroup_size = 64 we'd need
* to return { ~0, ~0, 0, 0 }.
*
* In case (2) it turns out that "result" will be ~0, because
* "ballot_bit_size - subgroup_size" is also a multiple of
* "ballot_bit_size" and since nir_ushr masks the shift value it will
* shifted by 0. This means that the first component can just be "result"
* in all cases. The other components will also get the correct value in
* case (1) if we just use the rule in case (2), so we'll get the correct
* result if we just follow (2) and then replace the first component with
* "result".
*/
nir_const_value min_idx[4] = { 0 };
for (unsigned i = 0; i < options->ballot_components; i++)
min_idx[i].i32 = i * options->ballot_bit_size;
nir_ssa_def *min_idx_val = nir_build_imm(b, options->ballot_components, 32, min_idx);
nir_ssa_def *result_extended =
nir_pad_vector_imm_int(b, result, ~0ull, options->ballot_components);
return nir_bcsel(b, nir_ult(b, min_idx_val, subgroup_size),
result_extended, nir_imm_intN_t(b, 0, options->ballot_bit_size));
}
static nir_ssa_def *
vec_bit_count(nir_builder *b, nir_ssa_def *value)
{
nir_ssa_def *vec_result = nir_bit_count(b, value);
nir_ssa_def *result = nir_channel(b, vec_result, 0);
for (unsigned i = 1; i < value->num_components; i++)
result = nir_iadd(b, result, nir_channel(b, vec_result, i));
return result;
}
static nir_ssa_def *
vec_find_lsb(nir_builder *b, nir_ssa_def *value)
{
nir_ssa_def *vec_result = nir_find_lsb(b, value);
nir_ssa_def *result = nir_imm_int(b, -1);
for (int i = value->num_components - 1; i >= 0; i--) {
nir_ssa_def *channel = nir_channel(b, vec_result, i);
/* result = channel >= 0 ? (i * bitsize + channel) : result */
result = nir_bcsel(b, nir_ige(b, channel, nir_imm_int(b, 0)),
nir_iadd_imm(b, channel, i * value->bit_size),
result);
}
return result;
}
static nir_ssa_def *
vec_find_msb(nir_builder *b, nir_ssa_def *value)
{
nir_ssa_def *vec_result = nir_ufind_msb(b, value);
nir_ssa_def *result = nir_imm_int(b, -1);
for (unsigned i = 0; i < value->num_components; i++) {
nir_ssa_def *channel = nir_channel(b, vec_result, i);
/* result = channel >= 0 ? (i * bitsize + channel) : result */
result = nir_bcsel(b, nir_ige(b, channel, nir_imm_int(b, 0)),
nir_iadd_imm(b, channel, i * value->bit_size),
result);
}
return result;
}
static nir_ssa_def *
@@ -410,32 +537,24 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
if (!options->lower_subgroup_masks)
return NULL;
/* If either the result or the requested bit size is 64-bits then we
* know that we have 64-bit types and using them will probably be more
* efficient than messing around with 32-bit shifts and packing.
*/
const unsigned bit_size = MAX2(options->ballot_bit_size,
intrin->dest.ssa.bit_size);
nir_ssa_def *count = nir_load_subgroup_invocation(b);
nir_ssa_def *val;
switch (intrin->intrinsic) {
case nir_intrinsic_load_subgroup_eq_mask:
val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count);
val = build_subgroup_eq_mask(b, options);
break;
case nir_intrinsic_load_subgroup_ge_mask:
val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
build_subgroup_mask(b, bit_size, options));
val = nir_iand(b, build_subgroup_ge_mask(b, options),
build_subgroup_mask(b, options));
break;
case nir_intrinsic_load_subgroup_gt_mask:
val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
build_subgroup_mask(b, bit_size, options));
val = nir_iand(b, build_subgroup_gt_mask(b, options),
build_subgroup_mask(b, options));
break;
case nir_intrinsic_load_subgroup_le_mask:
val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
val = nir_inot(b, build_subgroup_gt_mask(b, options));
break;
case nir_intrinsic_load_subgroup_lt_mask:
val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count));
val = nir_inot(b, build_subgroup_ge_mask(b, options));
break;
default:
unreachable("you seriously can't tell this is unreachable?");
@@ -447,11 +566,13 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
}
case nir_intrinsic_ballot: {
if (intrin->dest.ssa.num_components == 1 &&
if (intrin->dest.ssa.num_components == options->ballot_components &&
intrin->dest.ssa.bit_size == options->ballot_bit_size)
return NULL;
nir_ssa_def *ballot = nir_ballot(b, 1, options->ballot_bit_size, intrin->src[0].ssa);
nir_ssa_def *ballot =
nir_ballot(b, options->ballot_components, options->ballot_bit_size,
intrin->src[0].ssa);
return uint_to_ballot_type(b, ballot,
intrin->dest.ssa.num_components,
@@ -464,7 +585,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
case nir_intrinsic_ballot_find_msb: {
assert(intrin->src[0].is_ssa);
nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
options->ballot_bit_size);
options);
if (intrin->intrinsic != nir_intrinsic_ballot_bitfield_extract &&
intrin->intrinsic != nir_intrinsic_ballot_find_lsb) {
@@ -487,22 +608,31 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
* valid bits, we hit the undefined results case and we can return
* anything we want.
*/
int_val = nir_iand(b, int_val,
build_subgroup_mask(b, options->ballot_bit_size, options));
int_val = nir_iand(b, int_val, build_subgroup_mask(b, options));
}
switch (intrin->intrinsic) {
case nir_intrinsic_ballot_bitfield_extract:
case nir_intrinsic_ballot_bitfield_extract: {
assert(intrin->src[1].is_ssa);
return nir_i2b(b, nir_iand(b, nir_ushr(b, int_val,
intrin->src[1].ssa),
nir_imm_intN_t(b, 1, options->ballot_bit_size)));
nir_ssa_def *idx = intrin->src[1].ssa;
if (int_val->num_components > 1) {
/* idx will be truncated by nir_ushr, so we just need to select
* the right component using the bits of idx that are truncated in
* the shift.
*/
int_val =
nir_vector_extract(b, int_val,
nir_udiv_imm(b, idx, int_val->bit_size));
}
return nir_i2b(b, nir_iand_imm(b, nir_ushr(b, int_val, idx), 1));
}
case nir_intrinsic_ballot_bit_count_reduce:
return nir_bit_count(b, int_val);
return vec_bit_count(b, int_val);
case nir_intrinsic_ballot_find_lsb:
return nir_find_lsb(b, int_val);
return vec_find_lsb(b, int_val);
case nir_intrinsic_ballot_find_msb:
return nir_ufind_msb(b, int_val);
return vec_find_msb(b, int_val);
default:
unreachable("you seriously can't tell this is unreachable?");
}
@@ -510,20 +640,18 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
case nir_intrinsic_ballot_bit_count_exclusive:
case nir_intrinsic_ballot_bit_count_inclusive: {
nir_ssa_def *count = nir_load_subgroup_invocation(b);
nir_ssa_def *mask = nir_imm_intN_t(b, ~0ull, options->ballot_bit_size);
nir_ssa_def *mask;
if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) {
const unsigned bits = options->ballot_bit_size;
mask = nir_ushr(b, mask, nir_isub(b, nir_imm_int(b, bits - 1), count));
mask = nir_inot(b, build_subgroup_gt_mask(b, options));
} else {
mask = nir_inot(b, nir_ishl(b, mask, count));
mask = nir_inot(b, build_subgroup_ge_mask(b, options));
}
assert(intrin->src[0].is_ssa);
nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
options->ballot_bit_size);
options);
return nir_bit_count(b, nir_iand(b, int_val, mask));
return vec_bit_count(b, nir_iand(b, int_val, mask));
}
case nir_intrinsic_elect: {

View File

@@ -2354,6 +2354,7 @@ void lp_build_opt_nir(struct nir_shader *nir)
const nir_lower_subgroups_options subgroups_options = {
.subgroup_size = lp_native_vector_width / 32,
.ballot_bit_size = 32,
.ballot_components = 1,
.lower_to_scalar = true,
.lower_subgroup_masks = true,
};

View File

@@ -3115,6 +3115,7 @@ Converter::run()
struct nir_lower_subgroups_options subgroup_options = {};
subgroup_options.subgroup_size = 32;
subgroup_options.ballot_bit_size = 32;
subgroup_options.ballot_components = 1;
subgroup_options.lower_elect = true;
/* prepare for IO lowering */

View File

@@ -821,6 +821,7 @@ static void si_lower_nir(struct si_screen *sscreen, struct nir_shader *nir)
const nir_lower_subgroups_options subgroups_options = {
.subgroup_size = 64,
.ballot_bit_size = 64,
.ballot_components = 1,
.lower_to_scalar = true,
.lower_subgroup_masks = true,
.lower_vote_trivial = false,

View File

@@ -874,6 +874,7 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir,
const nir_lower_subgroups_options subgroups_options = {
.ballot_bit_size = 32,
.ballot_components = 1,
.lower_to_scalar = true,
.lower_vote_trivial = !is_scalar,
.lower_shuffle = true,
@@ -1354,6 +1355,7 @@ brw_nir_apply_key(nir_shader *nir,
.subgroup_size = get_subgroup_size(nir->info.stage, key,
max_subgroup_size),
.ballot_bit_size = 32,
.ballot_components = 1,
.lower_subgroup_masks = true,
};
OPT(nir_lower_subgroups, &subgroups_options);