spirv: Add subgroup ballot support
Reviewed-by: Iago Toral Quiroga <itoral@igalia.com>
This commit is contained in:
@@ -45,6 +45,7 @@ struct spirv_supported_capabilities {
|
|||||||
bool variable_pointers;
|
bool variable_pointers;
|
||||||
bool storage_16bit;
|
bool storage_16bit;
|
||||||
bool shader_viewport_index_layer;
|
bool shader_viewport_index_layer;
|
||||||
|
bool subgroup_ballot;
|
||||||
bool subgroup_basic;
|
bool subgroup_basic;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -3296,6 +3296,11 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
|
|||||||
spv_check_supported(subgroup_basic, cap);
|
spv_check_supported(subgroup_basic, cap);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
case SpvCapabilitySubgroupBallotKHR:
|
||||||
|
case SpvCapabilityGroupNonUniformBallot:
|
||||||
|
spv_check_supported(subgroup_ballot, cap);
|
||||||
|
break;
|
||||||
|
|
||||||
case SpvCapabilityVariablePointersStorageBuffer:
|
case SpvCapabilityVariablePointersStorageBuffer:
|
||||||
case SpvCapabilityVariablePointers:
|
case SpvCapabilityVariablePointers:
|
||||||
spv_check_supported(variable_pointers, cap);
|
spv_check_supported(variable_pointers, cap);
|
||||||
|
@@ -23,6 +23,44 @@
|
|||||||
|
|
||||||
#include "vtn_private.h"
|
#include "vtn_private.h"
|
||||||
|
|
||||||
|
static void
|
||||||
|
vtn_build_subgroup_instr(struct vtn_builder *b,
|
||||||
|
nir_intrinsic_op nir_op,
|
||||||
|
struct vtn_ssa_value *dst,
|
||||||
|
struct vtn_ssa_value *src0,
|
||||||
|
nir_ssa_def *index)
|
||||||
|
{
|
||||||
|
/* Some of the subgroup operations take an index. SPIR-V allows this to be
|
||||||
|
* any integer type. To make things simpler for drivers, we only support
|
||||||
|
* 32-bit indices.
|
||||||
|
*/
|
||||||
|
if (index && index->bit_size != 32)
|
||||||
|
index = nir_u2u32(&b->nb, index);
|
||||||
|
|
||||||
|
vtn_assert(dst->type == src0->type);
|
||||||
|
if (!glsl_type_is_vector_or_scalar(dst->type)) {
|
||||||
|
for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
|
||||||
|
vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
|
||||||
|
src0->elems[i], index);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
nir_intrinsic_instr *intrin =
|
||||||
|
nir_intrinsic_instr_create(b->nb.shader, nir_op);
|
||||||
|
nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
|
||||||
|
dst->type, NULL);
|
||||||
|
intrin->num_components = intrin->dest.ssa.num_components;
|
||||||
|
|
||||||
|
intrin->src[0] = nir_src_for_ssa(src0->def);
|
||||||
|
if (index)
|
||||||
|
intrin->src[1] = nir_src_for_ssa(index);
|
||||||
|
|
||||||
|
nir_builder_instr_insert(&b->nb, &intrin->instr);
|
||||||
|
|
||||||
|
dst->def = &intrin->dest.ssa;
|
||||||
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
||||||
const uint32_t *w, unsigned count)
|
const uint32_t *w, unsigned count)
|
||||||
@@ -43,17 +81,106 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case SpvOpGroupNonUniformAll:
|
case SpvOpGroupNonUniformBallot: {
|
||||||
case SpvOpGroupNonUniformAny:
|
vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
|
||||||
case SpvOpGroupNonUniformAllEqual:
|
"OpGroupNonUniformBallot must return a uvec4");
|
||||||
case SpvOpGroupNonUniformBroadcast:
|
nir_intrinsic_instr *ballot =
|
||||||
case SpvOpGroupNonUniformBroadcastFirst:
|
nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
|
||||||
case SpvOpGroupNonUniformBallot:
|
ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
|
||||||
case SpvOpGroupNonUniformInverseBallot:
|
nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
|
||||||
|
ballot->num_components = 4;
|
||||||
|
nir_builder_instr_insert(&b->nb, &ballot->instr);
|
||||||
|
val->ssa->def = &ballot->dest.ssa;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case SpvOpGroupNonUniformInverseBallot: {
|
||||||
|
/* This one is just a BallotBitfieldExtract with subgroup invocation.
|
||||||
|
* We could add a NIR intrinsic but it's easier to just lower it on the
|
||||||
|
* spot.
|
||||||
|
*/
|
||||||
|
nir_intrinsic_instr *intrin =
|
||||||
|
nir_intrinsic_instr_create(b->nb.shader,
|
||||||
|
nir_intrinsic_ballot_bitfield_extract);
|
||||||
|
|
||||||
|
intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
|
||||||
|
intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
|
||||||
|
|
||||||
|
nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
|
||||||
|
nir_builder_instr_insert(&b->nb, &intrin->instr);
|
||||||
|
|
||||||
|
val->ssa->def = &intrin->dest.ssa;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
case SpvOpGroupNonUniformBallotBitExtract:
|
case SpvOpGroupNonUniformBallotBitExtract:
|
||||||
case SpvOpGroupNonUniformBallotBitCount:
|
case SpvOpGroupNonUniformBallotBitCount:
|
||||||
case SpvOpGroupNonUniformBallotFindLSB:
|
case SpvOpGroupNonUniformBallotFindLSB:
|
||||||
case SpvOpGroupNonUniformBallotFindMSB:
|
case SpvOpGroupNonUniformBallotFindMSB: {
|
||||||
|
nir_ssa_def *src0, *src1 = NULL;
|
||||||
|
nir_intrinsic_op op;
|
||||||
|
switch (opcode) {
|
||||||
|
case SpvOpGroupNonUniformBallotBitExtract:
|
||||||
|
op = nir_intrinsic_ballot_bitfield_extract;
|
||||||
|
src0 = vtn_ssa_value(b, w[4])->def;
|
||||||
|
src1 = vtn_ssa_value(b, w[5])->def;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformBallotBitCount:
|
||||||
|
switch ((SpvGroupOperation)w[4]) {
|
||||||
|
case SpvGroupOperationReduce:
|
||||||
|
op = nir_intrinsic_ballot_bit_count_reduce;
|
||||||
|
break;
|
||||||
|
case SpvGroupOperationInclusiveScan:
|
||||||
|
op = nir_intrinsic_ballot_bit_count_inclusive;
|
||||||
|
break;
|
||||||
|
case SpvGroupOperationExclusiveScan:
|
||||||
|
op = nir_intrinsic_ballot_bit_count_exclusive;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
unreachable("Invalid group operation");
|
||||||
|
}
|
||||||
|
src0 = vtn_ssa_value(b, w[5])->def;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformBallotFindLSB:
|
||||||
|
op = nir_intrinsic_ballot_find_lsb;
|
||||||
|
src0 = vtn_ssa_value(b, w[4])->def;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformBallotFindMSB:
|
||||||
|
op = nir_intrinsic_ballot_find_msb;
|
||||||
|
src0 = vtn_ssa_value(b, w[4])->def;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
unreachable("Unhandled opcode");
|
||||||
|
}
|
||||||
|
|
||||||
|
nir_intrinsic_instr *intrin =
|
||||||
|
nir_intrinsic_instr_create(b->nb.shader, op);
|
||||||
|
|
||||||
|
intrin->src[0] = nir_src_for_ssa(src0);
|
||||||
|
if (src1)
|
||||||
|
intrin->src[1] = nir_src_for_ssa(src1);
|
||||||
|
|
||||||
|
nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
|
||||||
|
nir_builder_instr_insert(&b->nb, &intrin->instr);
|
||||||
|
|
||||||
|
val->ssa->def = &intrin->dest.ssa;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case SpvOpGroupNonUniformBroadcastFirst:
|
||||||
|
vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
|
||||||
|
val->ssa, vtn_ssa_value(b, w[4]), NULL);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case SpvOpGroupNonUniformBroadcast:
|
||||||
|
vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
|
||||||
|
val->ssa, vtn_ssa_value(b, w[4]),
|
||||||
|
vtn_ssa_value(b, w[5])->def);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case SpvOpGroupNonUniformAll:
|
||||||
|
case SpvOpGroupNonUniformAny:
|
||||||
|
case SpvOpGroupNonUniformAllEqual:
|
||||||
case SpvOpGroupNonUniformShuffle:
|
case SpvOpGroupNonUniformShuffle:
|
||||||
case SpvOpGroupNonUniformShuffleXor:
|
case SpvOpGroupNonUniformShuffleXor:
|
||||||
case SpvOpGroupNonUniformShuffleUp:
|
case SpvOpGroupNonUniformShuffleUp:
|
||||||
|
@@ -1317,6 +1317,26 @@ vtn_get_builtin_location(struct vtn_builder *b,
|
|||||||
*location = SYSTEM_VALUE_VIEW_INDEX;
|
*location = SYSTEM_VALUE_VIEW_INDEX;
|
||||||
set_mode_system_value(b, mode);
|
set_mode_system_value(b, mode);
|
||||||
break;
|
break;
|
||||||
|
case SpvBuiltInSubgroupEqMask:
|
||||||
|
*location = SYSTEM_VALUE_SUBGROUP_EQ_MASK,
|
||||||
|
set_mode_system_value(b, mode);
|
||||||
|
break;
|
||||||
|
case SpvBuiltInSubgroupGeMask:
|
||||||
|
*location = SYSTEM_VALUE_SUBGROUP_GE_MASK,
|
||||||
|
set_mode_system_value(b, mode);
|
||||||
|
break;
|
||||||
|
case SpvBuiltInSubgroupGtMask:
|
||||||
|
*location = SYSTEM_VALUE_SUBGROUP_GT_MASK,
|
||||||
|
set_mode_system_value(b, mode);
|
||||||
|
break;
|
||||||
|
case SpvBuiltInSubgroupLeMask:
|
||||||
|
*location = SYSTEM_VALUE_SUBGROUP_LE_MASK,
|
||||||
|
set_mode_system_value(b, mode);
|
||||||
|
break;
|
||||||
|
case SpvBuiltInSubgroupLtMask:
|
||||||
|
*location = SYSTEM_VALUE_SUBGROUP_LT_MASK,
|
||||||
|
set_mode_system_value(b, mode);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
vtn_fail("unsupported builtin");
|
vtn_fail("unsupported builtin");
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user