spirv: Add subgroup ballot support

Reviewed-by: Iago Toral Quiroga <itoral@igalia.com>
This commit is contained in:
Jason Ekstrand
2017-08-22 16:53:05 -07:00
parent 974daec495
commit 9812fce60b
4 changed files with 161 additions and 8 deletions

View File

@@ -45,6 +45,7 @@ struct spirv_supported_capabilities {
bool variable_pointers; bool variable_pointers;
bool storage_16bit; bool storage_16bit;
bool shader_viewport_index_layer; bool shader_viewport_index_layer;
bool subgroup_ballot;
bool subgroup_basic; bool subgroup_basic;
}; };

View File

@@ -3296,6 +3296,11 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
spv_check_supported(subgroup_basic, cap); spv_check_supported(subgroup_basic, cap);
break; break;
case SpvCapabilitySubgroupBallotKHR:
case SpvCapabilityGroupNonUniformBallot:
spv_check_supported(subgroup_ballot, cap);
break;
case SpvCapabilityVariablePointersStorageBuffer: case SpvCapabilityVariablePointersStorageBuffer:
case SpvCapabilityVariablePointers: case SpvCapabilityVariablePointers:
spv_check_supported(variable_pointers, cap); spv_check_supported(variable_pointers, cap);

View File

@@ -23,6 +23,44 @@
#include "vtn_private.h" #include "vtn_private.h"
/* Emits the subgroup intrinsic `nir_op`, reading from `src0` and storing the
 * resulting SSA def into `dst`.
 *
 * `dst` and `src0` must have the same GLSL type (asserted).  When that type
 * is not a vector or scalar, the function recurses over the elements of
 * `dst` and `src0` pairwise, passing the same `index` to each, and emits
 * nothing for the aggregate itself.
 *
 * `index` may be NULL.  When present it becomes the intrinsic's second
 * source, after being converted to 32 bits if it has another bit size.
 */
static void
vtn_build_subgroup_instr(struct vtn_builder *b,
                         nir_intrinsic_op nir_op,
                         struct vtn_ssa_value *dst,
                         struct vtn_ssa_value *src0,
                         nir_ssa_def *index)
{
   /* SPIR-V lets the index operand of a subgroup operation be any integer
    * type.  Drivers only have to deal with 32-bit indices, so normalize the
    * value here before it is used by this call or any recursive call.
    */
   if (index != NULL && index->bit_size != 32)
      index = nir_u2u32(&b->nb, index);

   vtn_assert(dst->type == src0->type);

   if (glsl_type_is_vector_or_scalar(dst->type)) {
      /* Leaf case: build a single intrinsic for this vector/scalar. */
      nir_intrinsic_instr *subgroup_op =
         nir_intrinsic_instr_create(b->nb.shader, nir_op);

      subgroup_op->src[0] = nir_src_for_ssa(src0->def);
      if (index != NULL)
         subgroup_op->src[1] = nir_src_for_ssa(index);

      /* The destination takes its shape from the value's GLSL type, and the
       * intrinsic's component count mirrors the destination.
       */
      nir_ssa_dest_init_for_type(&subgroup_op->instr, &subgroup_op->dest,
                                 dst->type, NULL);
      subgroup_op->num_components = subgroup_op->dest.ssa.num_components;

      nir_builder_instr_insert(&b->nb, &subgroup_op->instr);

      dst->def = &subgroup_op->dest.ssa;
      return;
   }

   /* Composite case: handle every element on its own. */
   for (unsigned elem = 0; elem < glsl_get_length(dst->type); elem++) {
      vtn_build_subgroup_instr(b, nir_op, dst->elems[elem],
                               src0->elems[elem], index);
   }
}
void void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count) const uint32_t *w, unsigned count)
@@ -43,17 +81,106 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
break; break;
} }
case SpvOpGroupNonUniformAll: case SpvOpGroupNonUniformBallot: {
case SpvOpGroupNonUniformAny: vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
case SpvOpGroupNonUniformAllEqual: "OpGroupNonUniformBallot must return a uvec4");
case SpvOpGroupNonUniformBroadcast: nir_intrinsic_instr *ballot =
case SpvOpGroupNonUniformBroadcastFirst: nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
case SpvOpGroupNonUniformBallot: ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
case SpvOpGroupNonUniformInverseBallot: nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
ballot->num_components = 4;
nir_builder_instr_insert(&b->nb, &ballot->instr);
val->ssa->def = &ballot->dest.ssa;
break;
}
case SpvOpGroupNonUniformInverseBallot: {
/* This one is just a BallotBitfieldExtract with subgroup invocation.
* We could add a NIR intrinsic but it's easier to just lower it on the
* spot.
*/
nir_intrinsic_instr *intrin =
nir_intrinsic_instr_create(b->nb.shader,
nir_intrinsic_ballot_bitfield_extract);
intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
nir_builder_instr_insert(&b->nb, &intrin->instr);
val->ssa->def = &intrin->dest.ssa;
break;
}
case SpvOpGroupNonUniformBallotBitExtract: case SpvOpGroupNonUniformBallotBitExtract:
case SpvOpGroupNonUniformBallotBitCount: case SpvOpGroupNonUniformBallotBitCount:
case SpvOpGroupNonUniformBallotFindLSB: case SpvOpGroupNonUniformBallotFindLSB:
case SpvOpGroupNonUniformBallotFindMSB: case SpvOpGroupNonUniformBallotFindMSB: {
nir_ssa_def *src0, *src1 = NULL;
nir_intrinsic_op op;
switch (opcode) {
case SpvOpGroupNonUniformBallotBitExtract:
op = nir_intrinsic_ballot_bitfield_extract;
src0 = vtn_ssa_value(b, w[4])->def;
src1 = vtn_ssa_value(b, w[5])->def;
break;
case SpvOpGroupNonUniformBallotBitCount:
switch ((SpvGroupOperation)w[4]) {
case SpvGroupOperationReduce:
op = nir_intrinsic_ballot_bit_count_reduce;
break;
case SpvGroupOperationInclusiveScan:
op = nir_intrinsic_ballot_bit_count_inclusive;
break;
case SpvGroupOperationExclusiveScan:
op = nir_intrinsic_ballot_bit_count_exclusive;
break;
default:
unreachable("Invalid group operation");
}
src0 = vtn_ssa_value(b, w[5])->def;
break;
case SpvOpGroupNonUniformBallotFindLSB:
op = nir_intrinsic_ballot_find_lsb;
src0 = vtn_ssa_value(b, w[4])->def;
break;
case SpvOpGroupNonUniformBallotFindMSB:
op = nir_intrinsic_ballot_find_msb;
src0 = vtn_ssa_value(b, w[4])->def;
break;
default:
unreachable("Unhandled opcode");
}
nir_intrinsic_instr *intrin =
nir_intrinsic_instr_create(b->nb.shader, op);
intrin->src[0] = nir_src_for_ssa(src0);
if (src1)
intrin->src[1] = nir_src_for_ssa(src1);
nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
nir_builder_instr_insert(&b->nb, &intrin->instr);
val->ssa->def = &intrin->dest.ssa;
break;
}
case SpvOpGroupNonUniformBroadcastFirst:
vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
val->ssa, vtn_ssa_value(b, w[4]), NULL);
break;
case SpvOpGroupNonUniformBroadcast:
vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
val->ssa, vtn_ssa_value(b, w[4]),
vtn_ssa_value(b, w[5])->def);
break;
case SpvOpGroupNonUniformAll:
case SpvOpGroupNonUniformAny:
case SpvOpGroupNonUniformAllEqual:
case SpvOpGroupNonUniformShuffle: case SpvOpGroupNonUniformShuffle:
case SpvOpGroupNonUniformShuffleXor: case SpvOpGroupNonUniformShuffleXor:
case SpvOpGroupNonUniformShuffleUp: case SpvOpGroupNonUniformShuffleUp:

View File

@@ -1317,6 +1317,26 @@ vtn_get_builtin_location(struct vtn_builder *b,
*location = SYSTEM_VALUE_VIEW_INDEX; *location = SYSTEM_VALUE_VIEW_INDEX;
set_mode_system_value(b, mode); set_mode_system_value(b, mode);
break; break;
case SpvBuiltInSubgroupEqMask:
*location = SYSTEM_VALUE_SUBGROUP_EQ_MASK,
set_mode_system_value(b, mode);
break;
case SpvBuiltInSubgroupGeMask:
*location = SYSTEM_VALUE_SUBGROUP_GE_MASK,
set_mode_system_value(b, mode);
break;
case SpvBuiltInSubgroupGtMask:
*location = SYSTEM_VALUE_SUBGROUP_GT_MASK,
set_mode_system_value(b, mode);
break;
case SpvBuiltInSubgroupLeMask:
*location = SYSTEM_VALUE_SUBGROUP_LE_MASK,
set_mode_system_value(b, mode);
break;
case SpvBuiltInSubgroupLtMask:
*location = SYSTEM_VALUE_SUBGROUP_LT_MASK,
set_mode_system_value(b, mode);
break;
default: default:
vtn_fail("unsupported builtin"); vtn_fail("unsupported builtin");
} }