spirv: Add support for subgroup arithmetic
Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com> Reviewed-by: Iago Toral Quiroga <itoral@igalia.com>
This commit is contained in:
@@ -45,6 +45,7 @@ struct spirv_supported_capabilities {
|
||||
bool variable_pointers;
|
||||
bool storage_16bit;
|
||||
bool shader_viewport_index_layer;
|
||||
bool subgroup_arithmetic;
|
||||
bool subgroup_ballot;
|
||||
bool subgroup_basic;
|
||||
bool subgroup_quad;
|
||||
|
@@ -3313,6 +3313,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
|
||||
case SpvCapabilityGroupNonUniformQuad:
|
||||
spv_check_supported(subgroup_quad, cap);
|
||||
|
||||
case SpvCapabilityGroupNonUniformArithmetic:
|
||||
case SpvCapabilityGroupNonUniformClustered:
|
||||
spv_check_supported(subgroup_arithmetic, cap);
|
||||
|
||||
case SpvCapabilityVariablePointersStorageBuffer:
|
||||
case SpvCapabilityVariablePointers:
|
||||
spv_check_supported(variable_pointers, cap);
|
||||
|
@@ -28,7 +28,9 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
|
||||
nir_intrinsic_op nir_op,
|
||||
struct vtn_ssa_value *dst,
|
||||
struct vtn_ssa_value *src0,
|
||||
nir_ssa_def *index)
|
||||
nir_ssa_def *index,
|
||||
unsigned const_idx0,
|
||||
unsigned const_idx1)
|
||||
{
|
||||
/* Some of the subgroup operations take an index. SPIR-V allows this to be
|
||||
* any integer type. To make things simpler for drivers, we only support
|
||||
@@ -41,7 +43,8 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
|
||||
if (!glsl_type_is_vector_or_scalar(dst->type)) {
|
||||
for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
|
||||
vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
|
||||
src0->elems[i], index);
|
||||
src0->elems[i], index,
|
||||
const_idx0, const_idx1);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -56,6 +59,9 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
|
||||
if (index)
|
||||
intrin->src[1] = nir_src_for_ssa(index);
|
||||
|
||||
intrin->const_index[0] = const_idx0;
|
||||
intrin->const_index[1] = const_idx1;
|
||||
|
||||
nir_builder_instr_insert(&b->nb, &intrin->instr);
|
||||
|
||||
dst->def = &intrin->dest.ssa;
|
||||
@@ -169,13 +175,13 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
||||
|
||||
case SpvOpGroupNonUniformBroadcastFirst:
|
||||
vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
|
||||
val->ssa, vtn_ssa_value(b, w[4]), NULL);
|
||||
val->ssa, vtn_ssa_value(b, w[4]), NULL, 0, 0);
|
||||
break;
|
||||
|
||||
case SpvOpGroupNonUniformBroadcast:
|
||||
vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
|
||||
val->ssa, vtn_ssa_value(b, w[4]),
|
||||
vtn_ssa_value(b, w[5])->def);
|
||||
vtn_ssa_value(b, w[5])->def, 0, 0);
|
||||
break;
|
||||
|
||||
case SpvOpGroupNonUniformAll:
|
||||
@@ -248,14 +254,14 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
||||
unreachable("Invalid opcode");
|
||||
}
|
||||
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
|
||||
vtn_ssa_value(b, w[5])->def);
|
||||
vtn_ssa_value(b, w[5])->def, 0, 0);
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpGroupNonUniformQuadBroadcast:
|
||||
vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
|
||||
val->ssa, vtn_ssa_value(b, w[4]),
|
||||
vtn_ssa_value(b, w[5])->def);
|
||||
vtn_ssa_value(b, w[5])->def, 0, 0);
|
||||
break;
|
||||
|
||||
case SpvOpGroupNonUniformQuadSwap: {
|
||||
@@ -272,7 +278,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
||||
op = nir_intrinsic_quad_swap_diagonal;
|
||||
break;
|
||||
}
|
||||
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]), NULL);
|
||||
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
|
||||
NULL, 0, 0);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -291,7 +298,81 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
|
||||
case SpvOpGroupNonUniformBitwiseXor:
|
||||
case SpvOpGroupNonUniformLogicalAnd:
|
||||
case SpvOpGroupNonUniformLogicalOr:
|
||||
case SpvOpGroupNonUniformLogicalXor: {
|
||||
nir_op reduction_op;
|
||||
switch (opcode) {
|
||||
case SpvOpGroupNonUniformIAdd:
|
||||
reduction_op = nir_op_iadd;
|
||||
break;
|
||||
case SpvOpGroupNonUniformFAdd:
|
||||
reduction_op = nir_op_fadd;
|
||||
break;
|
||||
case SpvOpGroupNonUniformIMul:
|
||||
reduction_op = nir_op_imul;
|
||||
break;
|
||||
case SpvOpGroupNonUniformFMul:
|
||||
reduction_op = nir_op_fmul;
|
||||
break;
|
||||
case SpvOpGroupNonUniformSMin:
|
||||
reduction_op = nir_op_imin;
|
||||
break;
|
||||
case SpvOpGroupNonUniformUMin:
|
||||
reduction_op = nir_op_umin;
|
||||
break;
|
||||
case SpvOpGroupNonUniformFMin:
|
||||
reduction_op = nir_op_fmin;
|
||||
break;
|
||||
case SpvOpGroupNonUniformSMax:
|
||||
reduction_op = nir_op_imax;
|
||||
break;
|
||||
case SpvOpGroupNonUniformUMax:
|
||||
reduction_op = nir_op_umax;
|
||||
break;
|
||||
case SpvOpGroupNonUniformFMax:
|
||||
reduction_op = nir_op_fmax;
|
||||
break;
|
||||
case SpvOpGroupNonUniformBitwiseAnd:
|
||||
case SpvOpGroupNonUniformLogicalAnd:
|
||||
reduction_op = nir_op_iand;
|
||||
break;
|
||||
case SpvOpGroupNonUniformBitwiseOr:
|
||||
case SpvOpGroupNonUniformLogicalOr:
|
||||
reduction_op = nir_op_ior;
|
||||
break;
|
||||
case SpvOpGroupNonUniformBitwiseXor:
|
||||
case SpvOpGroupNonUniformLogicalXor:
|
||||
reduction_op = nir_op_ixor;
|
||||
break;
|
||||
default:
|
||||
unreachable("Invalid reduction operation");
|
||||
}
|
||||
|
||||
nir_intrinsic_op op;
|
||||
unsigned cluster_size = 0;
|
||||
switch ((SpvGroupOperation)w[4]) {
|
||||
case SpvGroupOperationReduce:
|
||||
op = nir_intrinsic_reduce;
|
||||
break;
|
||||
case SpvGroupOperationInclusiveScan:
|
||||
op = nir_intrinsic_inclusive_scan;
|
||||
break;
|
||||
case SpvGroupOperationExclusiveScan:
|
||||
op = nir_intrinsic_exclusive_scan;
|
||||
break;
|
||||
case SpvGroupOperationClusteredReduce:
|
||||
op = nir_intrinsic_reduce;
|
||||
assert(count == 7);
|
||||
cluster_size = vtn_constant_value(b, w[6])->values[0].u32[0];
|
||||
break;
|
||||
default:
|
||||
unreachable("Invalid group operation");
|
||||
}
|
||||
|
||||
vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
|
||||
NULL, reduction_op, cluster_size);
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
unreachable("Invalid SPIR-V opcode");
|
||||
}
|
||||
|
Reference in New Issue
Block a user