spirv: Add support for subgroup arithmetic

Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Reviewed-by: Iago Toral Quiroga <itoral@igalia.com>
This commit is contained in:
Jason Ekstrand
2017-08-29 20:10:35 -07:00
parent 789221dcfa
commit 57bff0a546
3 changed files with 94 additions and 8 deletions

View File

@@ -45,6 +45,7 @@ struct spirv_supported_capabilities {
    bool variable_pointers;
    bool storage_16bit;
    bool shader_viewport_index_layer;
+   bool subgroup_arithmetic;
    bool subgroup_ballot;
    bool subgroup_basic;
    bool subgroup_quad;

View File

@@ -3313,6 +3313,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
       case SpvCapabilityGroupNonUniformQuad:
         spv_check_supported(subgroup_quad, cap);
 
+      case SpvCapabilityGroupNonUniformArithmetic:
+      case SpvCapabilityGroupNonUniformClustered:
+         spv_check_supported(subgroup_arithmetic, cap);
+
       case SpvCapabilityVariablePointersStorageBuffer:
       case SpvCapabilityVariablePointers:
         spv_check_supported(variable_pointers, cap);

View File

@@ -28,7 +28,9 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
                          nir_intrinsic_op nir_op,
                          struct vtn_ssa_value *dst,
                          struct vtn_ssa_value *src0,
-                         nir_ssa_def *index)
+                         nir_ssa_def *index,
+                         unsigned const_idx0,
+                         unsigned const_idx1)
 {
    /* Some of the subgroup operations take an index.  SPIR-V allows this to be
     * any integer type.  To make things simpler for drivers, we only support
@@ -41,7 +43,8 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
    if (!glsl_type_is_vector_or_scalar(dst->type)) {
       for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
          vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
-                                  src0->elems[i], index);
+                                  src0->elems[i], index,
+                                  const_idx0, const_idx1);
       }
       return;
    }
@@ -56,6 +59,9 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
    if (index)
       intrin->src[1] = nir_src_for_ssa(index);
 
+   intrin->const_index[0] = const_idx0;
+   intrin->const_index[1] = const_idx1;
+
    nir_builder_instr_insert(&b->nb, &intrin->instr);
    dst->def = &intrin->dest.ssa;
@@ -169,13 +175,13 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
 
    case SpvOpGroupNonUniformBroadcastFirst:
       vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
-                               val->ssa, vtn_ssa_value(b, w[4]), NULL);
+                               val->ssa, vtn_ssa_value(b, w[4]), NULL, 0, 0);
       break;
 
    case SpvOpGroupNonUniformBroadcast:
       vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                                val->ssa, vtn_ssa_value(b, w[4]),
-                               vtn_ssa_value(b, w[5])->def);
+                               vtn_ssa_value(b, w[5])->def, 0, 0);
       break;
 
    case SpvOpGroupNonUniformAll:
case SpvOpGroupNonUniformAll: case SpvOpGroupNonUniformAll:
@@ -248,14 +254,14 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
          unreachable("Invalid opcode");
       }
       vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
-                               vtn_ssa_value(b, w[5])->def);
+                               vtn_ssa_value(b, w[5])->def, 0, 0);
       break;
    }
 
    case SpvOpGroupNonUniformQuadBroadcast:
       vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                                val->ssa, vtn_ssa_value(b, w[4]),
-                               vtn_ssa_value(b, w[5])->def);
+                               vtn_ssa_value(b, w[5])->def, 0, 0);
       break;
 
    case SpvOpGroupNonUniformQuadSwap: {
case SpvOpGroupNonUniformQuadSwap: { case SpvOpGroupNonUniformQuadSwap: {
@@ -272,7 +278,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
          op = nir_intrinsic_quad_swap_diagonal;
          break;
       }
-      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]), NULL);
+      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
+                               NULL, 0, 0);
       break;
    }
@@ -291,7 +298,81 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
    case SpvOpGroupNonUniformBitwiseXor:
    case SpvOpGroupNonUniformLogicalAnd:
    case SpvOpGroupNonUniformLogicalOr:
+   case SpvOpGroupNonUniformLogicalXor: {
+      nir_op reduction_op;
+      switch (opcode) {
+      case SpvOpGroupNonUniformIAdd:
+         reduction_op = nir_op_iadd;
+         break;
+      case SpvOpGroupNonUniformFAdd:
+         reduction_op = nir_op_fadd;
+         break;
+      case SpvOpGroupNonUniformIMul:
+         reduction_op = nir_op_imul;
+         break;
+      case SpvOpGroupNonUniformFMul:
+         reduction_op = nir_op_fmul;
+         break;
+      case SpvOpGroupNonUniformSMin:
+         reduction_op = nir_op_imin;
+         break;
+      case SpvOpGroupNonUniformUMin:
+         reduction_op = nir_op_umin;
+         break;
+      case SpvOpGroupNonUniformFMin:
+         reduction_op = nir_op_fmin;
+         break;
+      case SpvOpGroupNonUniformSMax:
+         reduction_op = nir_op_imax;
+         break;
+      case SpvOpGroupNonUniformUMax:
+         reduction_op = nir_op_umax;
+         break;
+      case SpvOpGroupNonUniformFMax:
+         reduction_op = nir_op_fmax;
+         break;
+      case SpvOpGroupNonUniformBitwiseAnd:
+      case SpvOpGroupNonUniformLogicalAnd:
+         reduction_op = nir_op_iand;
+         break;
+      case SpvOpGroupNonUniformBitwiseOr:
+      case SpvOpGroupNonUniformLogicalOr:
+         reduction_op = nir_op_ior;
+         break;
+      case SpvOpGroupNonUniformBitwiseXor:
    case SpvOpGroupNonUniformLogicalXor:
+         reduction_op = nir_op_ixor;
+         break;
+      default:
+         unreachable("Invalid reduction operation");
+      }
+
+      nir_intrinsic_op op;
+      unsigned cluster_size = 0;
+      switch ((SpvGroupOperation)w[4]) {
+      case SpvGroupOperationReduce:
+         op = nir_intrinsic_reduce;
+         break;
+      case SpvGroupOperationInclusiveScan:
+         op = nir_intrinsic_inclusive_scan;
+         break;
+      case SpvGroupOperationExclusiveScan:
+         op = nir_intrinsic_exclusive_scan;
+         break;
+      case SpvGroupOperationClusteredReduce:
+         op = nir_intrinsic_reduce;
+         assert(count == 7);
+         cluster_size = vtn_constant_value(b, w[6])->values[0].u32[0];
+         break;
+      default:
+         unreachable("Invalid group operation");
+      }
+
+      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
+                               NULL, reduction_op, cluster_size);
+      break;
+   }
+
    default:
       unreachable("Invalid SPIR-V opcode");
    }