spirv: Add support for subgroup arithmetic
Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Reviewed-by: Iago Toral Quiroga <itoral@igalia.com>
This commit is contained in:
@@ -45,6 +45,7 @@ struct spirv_supported_capabilities {
|
|||||||
bool variable_pointers;
|
bool variable_pointers;
|
||||||
bool storage_16bit;
|
bool storage_16bit;
|
||||||
bool shader_viewport_index_layer;
|
bool shader_viewport_index_layer;
|
||||||
|
bool subgroup_arithmetic;
|
||||||
bool subgroup_ballot;
|
bool subgroup_ballot;
|
||||||
bool subgroup_basic;
|
bool subgroup_basic;
|
||||||
bool subgroup_quad;
|
bool subgroup_quad;
|
||||||
|
@@ -3313,6 +3313,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
|
|||||||
case SpvCapabilityGroupNonUniformQuad:
|
case SpvCapabilityGroupNonUniformQuad:
|
||||||
spv_check_supported(subgroup_quad, cap);
|
spv_check_supported(subgroup_quad, cap);
|
||||||
|
|
||||||
|
case SpvCapabilityGroupNonUniformArithmetic:
|
||||||
|
case SpvCapabilityGroupNonUniformClustered:
|
||||||
|
spv_check_supported(subgroup_arithmetic, cap);
|
||||||
|
|
||||||
case SpvCapabilityVariablePointersStorageBuffer:
|
case SpvCapabilityVariablePointersStorageBuffer:
|
||||||
case SpvCapabilityVariablePointers:
|
case SpvCapabilityVariablePointers:
|
||||||
spv_check_supported(variable_pointers, cap);
|
spv_check_supported(variable_pointers, cap);
|
||||||
|
@@ -28,7 +28,9 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
|
|||||||
nir_intrinsic_op nir_op,
|
nir_intrinsic_op nir_op,
|
||||||
struct vtn_ssa_value *dst,
|
struct vtn_ssa_value *dst,
|
||||||
struct vtn_ssa_value *src0,
|
struct vtn_ssa_value *src0,
|
||||||
nir_ssa_def *index)
|
nir_ssa_def *index,
|
||||||
|
unsigned const_idx0,
|
||||||
|
unsigned const_idx1)
|
||||||
{
|
{
|
||||||
/* Some of the subgroup operations take an index. SPIR-V allows this to be
|
/* Some of the subgroup operations take an index. SPIR-V allows this to be
|
||||||
* any integer type. To make things simpler for drivers, we only support
|
* any integer type. To make things simpler for drivers, we only support
|
||||||
@@ -41,7 +43,8 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
|
|||||||
if (!glsl_type_is_vector_or_scalar(dst->type)) {
|
if (!glsl_type_is_vector_or_scalar(dst->type)) {
|
||||||
for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
|
for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
|
||||||
vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
|
vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
|
||||||
src0->elems[i], index);
|
src0->elems[i], index,
|
||||||
|
const_idx0, const_idx1);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -56,6 +59,9 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
|
|||||||
if (index)
|
if (index)
|
||||||
intrin->src[1] = nir_src_for_ssa(index);
|
intrin->src[1] = nir_src_for_ssa(index);
|
||||||
|
|
||||||
|
intrin->const_index[0] = const_idx0;
|
||||||
|
intrin->const_index[1] = const_idx1;
|
||||||
|
|
||||||
nir_builder_instr_insert(&b->nb, &intrin->instr);
|
nir_builder_instr_insert(&b->nb, &intrin->instr);
|
||||||
|
|
||||||
dst->def = &intrin->dest.ssa;
|
dst->def = &intrin->dest.ssa;
|
||||||
@@ -169,13 +175,13 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
|||||||
|
|
||||||
case SpvOpGroupNonUniformBroadcastFirst:
|
case SpvOpGroupNonUniformBroadcastFirst:
|
||||||
vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
|
vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
|
||||||
val->ssa, vtn_ssa_value(b, w[4]), NULL);
|
val->ssa, vtn_ssa_value(b, w[4]), NULL, 0, 0);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case SpvOpGroupNonUniformBroadcast:
|
case SpvOpGroupNonUniformBroadcast:
|
||||||
vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
|
vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
|
||||||
val->ssa, vtn_ssa_value(b, w[4]),
|
val->ssa, vtn_ssa_value(b, w[4]),
|
||||||
vtn_ssa_value(b, w[5])->def);
|
vtn_ssa_value(b, w[5])->def, 0, 0);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case SpvOpGroupNonUniformAll:
|
case SpvOpGroupNonUniformAll:
|
||||||
@@ -248,14 +254,14 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
|||||||
unreachable("Invalid opcode");
|
unreachable("Invalid opcode");
|
||||||
}
|
}
|
||||||
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
|
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
|
||||||
vtn_ssa_value(b, w[5])->def);
|
vtn_ssa_value(b, w[5])->def, 0, 0);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case SpvOpGroupNonUniformQuadBroadcast:
|
case SpvOpGroupNonUniformQuadBroadcast:
|
||||||
vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
|
vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
|
||||||
val->ssa, vtn_ssa_value(b, w[4]),
|
val->ssa, vtn_ssa_value(b, w[4]),
|
||||||
vtn_ssa_value(b, w[5])->def);
|
vtn_ssa_value(b, w[5])->def, 0, 0);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case SpvOpGroupNonUniformQuadSwap: {
|
case SpvOpGroupNonUniformQuadSwap: {
|
||||||
@@ -272,7 +278,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
|||||||
op = nir_intrinsic_quad_swap_diagonal;
|
op = nir_intrinsic_quad_swap_diagonal;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]), NULL);
|
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
|
||||||
|
NULL, 0, 0);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -291,7 +298,81 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
|||||||
case SpvOpGroupNonUniformBitwiseXor:
|
case SpvOpGroupNonUniformBitwiseXor:
|
||||||
case SpvOpGroupNonUniformLogicalAnd:
|
case SpvOpGroupNonUniformLogicalAnd:
|
||||||
case SpvOpGroupNonUniformLogicalOr:
|
case SpvOpGroupNonUniformLogicalOr:
|
||||||
case SpvOpGroupNonUniformLogicalXor:
|
case SpvOpGroupNonUniformLogicalXor: {
|
||||||
|
nir_op reduction_op;
|
||||||
|
switch (opcode) {
|
||||||
|
case SpvOpGroupNonUniformIAdd:
|
||||||
|
reduction_op = nir_op_iadd;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformFAdd:
|
||||||
|
reduction_op = nir_op_fadd;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformIMul:
|
||||||
|
reduction_op = nir_op_imul;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformFMul:
|
||||||
|
reduction_op = nir_op_fmul;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformSMin:
|
||||||
|
reduction_op = nir_op_imin;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformUMin:
|
||||||
|
reduction_op = nir_op_umin;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformFMin:
|
||||||
|
reduction_op = nir_op_fmin;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformSMax:
|
||||||
|
reduction_op = nir_op_imax;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformUMax:
|
||||||
|
reduction_op = nir_op_umax;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformFMax:
|
||||||
|
reduction_op = nir_op_fmax;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformBitwiseAnd:
|
||||||
|
case SpvOpGroupNonUniformLogicalAnd:
|
||||||
|
reduction_op = nir_op_iand;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformBitwiseOr:
|
||||||
|
case SpvOpGroupNonUniformLogicalOr:
|
||||||
|
reduction_op = nir_op_ior;
|
||||||
|
break;
|
||||||
|
case SpvOpGroupNonUniformBitwiseXor:
|
||||||
|
case SpvOpGroupNonUniformLogicalXor:
|
||||||
|
reduction_op = nir_op_ixor;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
unreachable("Invalid reduction operation");
|
||||||
|
}
|
||||||
|
|
||||||
|
nir_intrinsic_op op;
|
||||||
|
unsigned cluster_size = 0;
|
||||||
|
switch ((SpvGroupOperation)w[4]) {
|
||||||
|
case SpvGroupOperationReduce:
|
||||||
|
op = nir_intrinsic_reduce;
|
||||||
|
break;
|
||||||
|
case SpvGroupOperationInclusiveScan:
|
||||||
|
op = nir_intrinsic_inclusive_scan;
|
||||||
|
break;
|
||||||
|
case SpvGroupOperationExclusiveScan:
|
||||||
|
op = nir_intrinsic_exclusive_scan;
|
||||||
|
break;
|
||||||
|
case SpvGroupOperationClusteredReduce:
|
||||||
|
op = nir_intrinsic_reduce;
|
||||||
|
assert(count == 7);
|
||||||
|
cluster_size = vtn_constant_value(b, w[6])->values[0].u32[0];
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
unreachable("Invalid group operation");
|
||||||
|
}
|
||||||
|
|
||||||
|
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
|
||||||
|
NULL, reduction_op, cluster_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
unreachable("Invalid SPIR-V opcode");
|
unreachable("Invalid SPIR-V opcode");
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user