spirv: Implement SPV_EXT_replicated_composites

Reviewed-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29509>
This commit is contained in:
Faith Ekstrand
2024-04-02 17:04:52 -05:00
committed by Marge Bot
parent fff42bcc66
commit c452143024

View File

@@ -135,6 +135,7 @@ static const struct spirv_capabilities implemented_capabilities = {
.RayTracingKHR = true,
.RayTracingPositionFetchKHR = true,
.RayTraversalPrimitiveCullingKHR = true,
.ReplicatedCompositesEXT = true,
.RoundingModeRTE = true,
.RoundingModeRTZ = true,
.RuntimeDescriptorArrayEXT = true,
@@ -2338,29 +2339,52 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
}
case SpvOpSpecConstantComposite:
case SpvOpConstantComposite: {
unsigned elem_count = count - 3;
unsigned expected_length = val->type->base_type == vtn_base_type_cooperative_matrix ?
case SpvOpConstantComposite:
case SpvOpConstantCompositeReplicateEXT:
case SpvOpSpecConstantCompositeReplicateEXT: {
const unsigned elem_count =
val->type->base_type == vtn_base_type_cooperative_matrix ?
1 : val->type->length;
vtn_fail_if(elem_count != expected_length,
"%s has %u constituents, expected %u",
spirv_op_to_string(opcode), elem_count, expected_length);
nir_constant **elems = ralloc_array(b, nir_constant *, elem_count);
val->is_undef_constant = true;
for (unsigned i = 0; i < elem_count; i++) {
struct vtn_value *elem_val = vtn_untyped_value(b, w[i + 3]);
if (opcode == SpvOpConstantCompositeReplicateEXT ||
opcode == SpvOpSpecConstantCompositeReplicateEXT) {
struct vtn_value *elem_val = vtn_untyped_value(b, w[3]);
if (elem_val->value_type == vtn_value_type_constant) {
elems[i] = elem_val->constant;
val->is_undef_constant = val->is_undef_constant &&
elem_val->is_undef_constant;
elems[0] = elem_val->constant;
val->is_undef_constant = false;
} else {
vtn_fail_if(elem_val->value_type != vtn_value_type_undef,
"only constants or undefs allowed for "
"SpvOpConstantComposite");
"only constants or undefs allowed for %s",
spirv_op_to_string(opcode));
/* to make it easier, just insert a NULL constant for now */
elems[i] = vtn_null_constant(b, elem_val->type);
elems[0] = vtn_null_constant(b, elem_val->type);
val->is_undef_constant = true;
}
for (unsigned i = 1; i < elem_count; i++)
elems[i] = elems[0];
} else {
vtn_fail_if(elem_count != count - 3,
"%s has %u constituents, expected %u",
spirv_op_to_string(opcode), count - 3, elem_count);
val->is_undef_constant = true;
for (unsigned i = 0; i < elem_count; i++) {
struct vtn_value *elem_val = vtn_untyped_value(b, w[i + 3]);
if (elem_val->value_type == vtn_value_type_constant) {
elems[i] = elem_val->constant;
val->is_undef_constant = val->is_undef_constant &&
elem_val->is_undef_constant;
} else {
vtn_fail_if(elem_val->value_type != vtn_value_type_undef,
"only constants or undefs allowed for %s",
spirv_op_to_string(opcode));
/* to make it easier, just insert a NULL constant for now */
elems[i] = vtn_null_constant(b, elem_val->type);
}
}
}
@@ -4538,7 +4562,8 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
w + 5);
break;
case SpvOpCompositeConstruct: {
case SpvOpCompositeConstruct:
case SpvOpCompositeConstructReplicateEXT: {
unsigned elems = count - 3;
assume(elems >= 1);
if (type->base_type == vtn_base_type_cooperative_matrix) {
@@ -4547,21 +4572,35 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
nir_cmat_construct(&b->nb, &mat->def, vtn_get_nir_ssa(b, w[3]));
vtn_set_ssa_value_var(b, ssa, mat->var);
} else if (glsl_type_is_vector_or_scalar(type->type)) {
nir_def *srcs[NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i < elems; i++) {
srcs[i] = vtn_get_nir_ssa(b, w[3 + i]);
vtn_assert(glsl_get_bit_size(type->type) == srcs[i]->bit_size);
if (opcode == SpvOpCompositeConstructReplicateEXT) {
nir_def *src = vtn_get_nir_ssa(b, w[3]);
vtn_assert(glsl_get_bit_size(type->type) == src->bit_size);
unsigned swiz[NIR_MAX_VEC_COMPONENTS] = { 0, };
ssa->def = nir_swizzle(&b->nb, src, swiz,
glsl_get_vector_elements(type->type));
} else {
nir_def *srcs[NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i < elems; i++) {
srcs[i] = vtn_get_nir_ssa(b, w[3 + i]);
vtn_assert(glsl_get_bit_size(type->type) == srcs[i]->bit_size);
}
ssa->def =
vtn_vector_construct(b, glsl_get_vector_elements(type->type),
elems, srcs);
}
ssa->def =
vtn_vector_construct(b, glsl_get_vector_elements(type->type),
elems, srcs);
} else {
vtn_fail_if(elems != type->length,
"%s has %u constituents, expected %u",
spirv_op_to_string(opcode), elems, type->length);
ssa->elems = vtn_alloc_array(b, struct vtn_ssa_value *, elems);
for (unsigned i = 0; i < elems; i++)
ssa->elems[i] = vtn_ssa_value(b, w[3 + i]);
ssa->elems = vtn_alloc_array(b, struct vtn_ssa_value *, type->length);
if (opcode == SpvOpCompositeConstructReplicateEXT) {
struct vtn_ssa_value *elem = vtn_ssa_value(b, w[3]);
for (unsigned i = 0; i < type->length; i++)
ssa->elems[i] = elem;
} else {
vtn_fail_if(elems != type->length,
"%s has %u constituents, expected %u",
spirv_op_to_string(opcode), elems, type->length);
for (unsigned i = 0; i < elems; i++)
ssa->elems[i] = vtn_ssa_value(b, w[3 + i]);
}
}
break;
}
@@ -5562,11 +5601,13 @@ vtn_handle_variable_or_type_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvOpConstantFalse:
case SpvOpConstant:
case SpvOpConstantComposite:
case SpvOpConstantCompositeReplicateEXT:
case SpvOpConstantNull:
case SpvOpSpecConstantTrue:
case SpvOpSpecConstantFalse:
case SpvOpSpecConstant:
case SpvOpSpecConstantComposite:
case SpvOpSpecConstantCompositeReplicateEXT:
case SpvOpSpecConstantOp:
vtn_handle_constant(b, opcode, w, count);
break;
@@ -6322,6 +6363,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvOpVectorInsertDynamic:
case SpvOpVectorShuffle:
case SpvOpCompositeConstruct:
case SpvOpCompositeConstructReplicateEXT:
case SpvOpCompositeExtract:
case SpvOpCompositeInsert:
case SpvOpCopyLogical: