From 8fa46b31a89fde179d87f0b714bc882ebfa43b0d Mon Sep 17 00:00:00 2001 From: Faith Ekstrand Date: Tue, 2 Apr 2024 18:49:22 -0500 Subject: [PATCH] spirv: Handle constant cooperative matrices in OpCompositeExtract Fixes: b98f87612bc1 ("spirv: Implement SPV_KHR_cooperative_matrix") Reviewed-by: Konstantin Seurer Part-of: --- src/compiler/spirv/spirv_to_nir.c | 49 ++++++++++++++++++------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index e9c252e1d8b..f2d12e967f6 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -2466,31 +2466,38 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, int elem = -1; const struct vtn_type *type = comp->type; for (unsigned i = deref_start; i < count; i++) { - vtn_fail_if(w[i] > type->length, - "%uth index of %s is %u but the type has only " - "%u elements", i - deref_start, - spirv_op_to_string(opcode), w[i], type->length); + if (type->base_type == vtn_base_type_cooperative_matrix) { + /* Cooperative matrices are always scalar constants. We don't + * care about the index w[i] because it's always replicated. + */ + type = type->component_type; + } else { + vtn_fail_if(w[i] > type->length, + "%uth index of %s is %u but the type has only " + "%u elements", i - deref_start, + spirv_op_to_string(opcode), w[i], type->length); - switch (type->base_type) { - case vtn_base_type_vector: - elem = w[i]; - type = type->array_element; - break; + switch (type->base_type) { + case vtn_base_type_vector: + elem = w[i]; + type = type->array_element; + break; - case vtn_base_type_matrix: - case vtn_base_type_array: - c = &(*c)->elements[w[i]]; - type = type->array_element; - break; + case vtn_base_type_matrix: + case vtn_base_type_array: + c = &(*c)->elements[w[i]]; + type = type->array_element; + break; - case vtn_base_type_struct: - c = &(*c)->elements[w[i]]; - type = type->members[w[i]]; - break; + case vtn_base_type_struct: + c = &(*c)->elements[w[i]]; + type = type->members[w[i]]; + break; - default: - vtn_fail("%s must only index into composite types", - spirv_op_to_string(opcode)); + default: + vtn_fail("%s must only index into composite types", + spirv_op_to_string(opcode)); + } } }