diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index d64ab18cf86..ea38861eb13 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -135,6 +135,7 @@ static const struct spirv_capabilities implemented_capabilities = {
    .RayTracingKHR = true,
    .RayTracingPositionFetchKHR = true,
    .RayTraversalPrimitiveCullingKHR = true,
+   .ReplicatedCompositesEXT = true,
    .RoundingModeRTE = true,
    .RoundingModeRTZ = true,
    .RuntimeDescriptorArrayEXT = true,
@@ -2338,29 +2339,52 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
    }
 
    case SpvOpSpecConstantComposite:
-   case SpvOpConstantComposite: {
-      unsigned elem_count = count - 3;
-      unsigned expected_length = val->type->base_type == vtn_base_type_cooperative_matrix ?
+   case SpvOpConstantComposite:
+   case SpvOpConstantCompositeReplicateEXT:
+   case SpvOpSpecConstantCompositeReplicateEXT: {
+      const unsigned elem_count =
+         val->type->base_type == vtn_base_type_cooperative_matrix ?
          1 : val->type->length;
-      vtn_fail_if(elem_count != expected_length,
-                  "%s has %u constituents, expected %u",
-                  spirv_op_to_string(opcode), elem_count, expected_length);
 
       nir_constant **elems = ralloc_array(b, nir_constant *, elem_count);
-      val->is_undef_constant = true;
-      for (unsigned i = 0; i < elem_count; i++) {
-         struct vtn_value *elem_val = vtn_untyped_value(b, w[i + 3]);
+      if (opcode == SpvOpConstantCompositeReplicateEXT ||
+          opcode == SpvOpSpecConstantCompositeReplicateEXT) {
+         struct vtn_value *elem_val = vtn_untyped_value(b, w[3]);
 
          if (elem_val->value_type == vtn_value_type_constant) {
-            elems[i] = elem_val->constant;
-            val->is_undef_constant = val->is_undef_constant &&
-                                     elem_val->is_undef_constant;
+            elems[0] = elem_val->constant;
+            val->is_undef_constant = false;
          } else {
            vtn_fail_if(elem_val->value_type != vtn_value_type_undef,
-                        "only constants or undefs allowed for "
-                        "SpvOpConstantComposite");
+                        "only constants or undefs allowed for %s",
+                        spirv_op_to_string(opcode));
            /* to make it easier, just insert a NULL constant for now */
-            elems[i] = vtn_null_constant(b, elem_val->type);
+            elems[0] = vtn_null_constant(b, elem_val->type);
+            val->is_undef_constant = true;
+         }
+
+         for (unsigned i = 1; i < elem_count; i++)
+            elems[i] = elems[0];
+      } else {
+         vtn_fail_if(elem_count != count - 3,
+                     "%s has %u constituents, expected %u",
+                     spirv_op_to_string(opcode), count - 3, elem_count);
+
+         val->is_undef_constant = true;
+         for (unsigned i = 0; i < elem_count; i++) {
+            struct vtn_value *elem_val = vtn_untyped_value(b, w[i + 3]);
+
+            if (elem_val->value_type == vtn_value_type_constant) {
+               elems[i] = elem_val->constant;
+               val->is_undef_constant = val->is_undef_constant &&
+                                        elem_val->is_undef_constant;
+            } else {
+               vtn_fail_if(elem_val->value_type != vtn_value_type_undef,
+                           "only constants or undefs allowed for %s",
+                           spirv_op_to_string(opcode));
+               /* to make it easier, just insert a NULL constant for now */
+               elems[i] = vtn_null_constant(b, elem_val->type);
+            }
          }
       }
 
@@ -4538,7 +4562,8 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
                                     w + 5);
       break;
 
-   case SpvOpCompositeConstruct: {
+   case SpvOpCompositeConstruct:
+   case SpvOpCompositeConstructReplicateEXT: {
       unsigned elems = count - 3;
       assume(elems >= 1);
       if (type->base_type == vtn_base_type_cooperative_matrix) {
@@ -4547,21 +4572,35 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
          nir_cmat_construct(&b->nb, &mat->def, vtn_get_nir_ssa(b, w[3]));
          vtn_set_ssa_value_var(b, ssa, mat->var);
       } else if (glsl_type_is_vector_or_scalar(type->type)) {
-         nir_def *srcs[NIR_MAX_VEC_COMPONENTS];
-         for (unsigned i = 0; i < elems; i++) {
-            srcs[i] = vtn_get_nir_ssa(b, w[3 + i]);
-            vtn_assert(glsl_get_bit_size(type->type) == srcs[i]->bit_size);
+         if (opcode == SpvOpCompositeConstructReplicateEXT) {
+            nir_def *src = vtn_get_nir_ssa(b, w[3]);
+            vtn_assert(glsl_get_bit_size(type->type) == src->bit_size);
+            unsigned swiz[NIR_MAX_VEC_COMPONENTS] = { 0, };
+            ssa->def = nir_swizzle(&b->nb, src, swiz,
+                                   glsl_get_vector_elements(type->type));
+         } else {
+            nir_def *srcs[NIR_MAX_VEC_COMPONENTS];
+            for (unsigned i = 0; i < elems; i++) {
+               srcs[i] = vtn_get_nir_ssa(b, w[3 + i]);
+               vtn_assert(glsl_get_bit_size(type->type) == srcs[i]->bit_size);
+            }
+            ssa->def =
+               vtn_vector_construct(b, glsl_get_vector_elements(type->type),
+                                    elems, srcs);
          }
-         ssa->def =
-            vtn_vector_construct(b, glsl_get_vector_elements(type->type),
-                                 elems, srcs);
       } else {
-         vtn_fail_if(elems != type->length,
-                     "%s has %u constituents, expected %u",
-                     spirv_op_to_string(opcode), elems, type->length);
-         ssa->elems = vtn_alloc_array(b, struct vtn_ssa_value *, elems);
-         for (unsigned i = 0; i < elems; i++)
-            ssa->elems[i] = vtn_ssa_value(b, w[3 + i]);
+         ssa->elems = vtn_alloc_array(b, struct vtn_ssa_value *, type->length);
+         if (opcode == SpvOpCompositeConstructReplicateEXT) {
+            struct vtn_ssa_value *elem = vtn_ssa_value(b, w[3]);
+            for (unsigned i = 0; i < type->length; i++)
+               ssa->elems[i] = elem;
+         } else {
+            vtn_fail_if(elems != type->length,
+                        "%s has %u constituents, expected %u",
+                        spirv_op_to_string(opcode), elems, type->length);
+            for (unsigned i = 0; i < elems; i++)
+               ssa->elems[i] = vtn_ssa_value(b, w[3 + i]);
+         }
       }
       break;
    }
@@ -5562,11 +5601,13 @@ vtn_handle_variable_or_type_instruction(struct vtn_builder *b, SpvOp opcode,
    case SpvOpConstantFalse:
    case SpvOpConstant:
    case SpvOpConstantComposite:
+   case SpvOpConstantCompositeReplicateEXT:
    case SpvOpConstantNull:
    case SpvOpSpecConstantTrue:
    case SpvOpSpecConstantFalse:
    case SpvOpSpecConstant:
    case SpvOpSpecConstantComposite:
+   case SpvOpSpecConstantCompositeReplicateEXT:
    case SpvOpSpecConstantOp:
       vtn_handle_constant(b, opcode, w, count);
       break;
@@ -6322,6 +6363,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
    case SpvOpVectorInsertDynamic:
    case SpvOpVectorShuffle:
    case SpvOpCompositeConstruct:
+   case SpvOpCompositeConstructReplicateEXT:
    case SpvOpCompositeExtract:
    case SpvOpCompositeInsert:
    case SpvOpCopyLogical:
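Note on the semantics the patch implements: the replicated forms carry a single constituent at w[3] that fills every element of the result, instead of one constituent per element starting at w[3] (for vectors the patch expresses this as an all-zero nir_swizzle). The standalone C sketch below is illustrative only, not Mesa code; the enum, function, and word layout values are made up, and only the replicate vs. non-replicate operand handling is taken from the patch.

    /* Toy sketch: contrast OpConstantComposite (one constituent id per
     * element, starting at w[3]) with OpConstantCompositeReplicateEXT
     * (a single constituent id at w[3] copied into every slot, matching
     * the elems[i] = elems[0] loop in the patch). */
    #include <stdio.h>

    enum toy_opcode {
       TOY_CONSTANT_COMPOSITE,
       TOY_CONSTANT_COMPOSITE_REPLICATE,
    };

    static void
    fill_constituents(enum toy_opcode op, const unsigned *w, unsigned count,
                      unsigned elem_count, unsigned *elems)
    {
       if (op == TOY_CONSTANT_COMPOSITE_REPLICATE) {
          /* Single source id, replicated into every element. */
          for (unsigned i = 0; i < elem_count; i++)
             elems[i] = w[3];
       } else {
          /* One source id per element; the instruction must supply
           * count - 3 constituents, i.e. exactly elem_count of them. */
          for (unsigned i = 0; i < elem_count && 3 + i < count; i++)
             elems[i] = w[3 + i];
       }
    }

    int main(void)
    {
       /* w[0] = opcode/word count, w[1] = result type, w[2] = result id,
        * w[3...] = constituent id(s); the ids here are made up. */
       const unsigned replicate_words[] = { 0, 10, 11, 42 };
       unsigned elems[4];

       fill_constituents(TOY_CONSTANT_COMPOSITE_REPLICATE, replicate_words,
                         4, 4, elems);
       for (unsigned i = 0; i < 4; i++)
          printf("elems[%u] = id %u\n", i, elems[i]);
       return 0;
    }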