diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index 09a5d6b607a..01661409db1 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -929,8 +929,11 @@ static void emit_alu(struct ntv_context *ctx, nir_alu_instr *alu) { SpvId src[nir_op_infos[alu->op].num_inputs]; - for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) + unsigned in_bit_sizes[nir_op_infos[alu->op].num_inputs]; + for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) { src[i] = get_alu_src(ctx, alu, i); + in_bit_sizes[i] = nir_src_bit_size(alu->src[i].src); + } SpvId dest_type = get_dest_type(ctx, &alu->dest.dest, nir_op_infos[alu->op].output_type); @@ -1154,51 +1157,67 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu) case nir_op_bany_fnequal2: case nir_op_bany_fnequal3: - case nir_op_bany_fnequal4: + case nir_op_bany_fnequal4: { assert(nir_op_infos[alu->op].num_inputs == 2); assert(alu_instr_src_components(alu, 0) == alu_instr_src_components(alu, 1)); - result = emit_binop(ctx, SpvOpFOrdNotEqual, + assert(in_bit_sizes[0] == in_bit_sizes[1]); + /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */ + SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpFOrdNotEqual; + result = emit_binop(ctx, op, get_bvec_type(ctx, alu_instr_src_components(alu, 0)), src[0], src[1]); result = emit_unop(ctx, SpvOpAny, dest_type, result); break; + } case nir_op_ball_fequal2: case nir_op_ball_fequal3: - case nir_op_ball_fequal4: + case nir_op_ball_fequal4: { assert(nir_op_infos[alu->op].num_inputs == 2); assert(alu_instr_src_components(alu, 0) == alu_instr_src_components(alu, 1)); - result = emit_binop(ctx, SpvOpFOrdEqual, + assert(in_bit_sizes[0] == in_bit_sizes[1]); + /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */ + SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpFOrdEqual; + result = emit_binop(ctx, op, get_bvec_type(ctx, alu_instr_src_components(alu, 0)), src[0], src[1]); result = emit_unop(ctx, SpvOpAll, dest_type, result); break; + } case nir_op_bany_inequal2: case nir_op_bany_inequal3: - case nir_op_bany_inequal4: + case nir_op_bany_inequal4: { assert(nir_op_infos[alu->op].num_inputs == 2); assert(alu_instr_src_components(alu, 0) == alu_instr_src_components(alu, 1)); - result = emit_binop(ctx, SpvOpINotEqual, + assert(in_bit_sizes[0] == in_bit_sizes[1]); + /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */ + SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpINotEqual; + result = emit_binop(ctx, op, get_bvec_type(ctx, alu_instr_src_components(alu, 0)), src[0], src[1]); result = emit_unop(ctx, SpvOpAny, dest_type, result); break; + } case nir_op_ball_iequal2: case nir_op_ball_iequal3: - case nir_op_ball_iequal4: + case nir_op_ball_iequal4: { assert(nir_op_infos[alu->op].num_inputs == 2); assert(alu_instr_src_components(alu, 0) == alu_instr_src_components(alu, 1)); - result = emit_binop(ctx, SpvOpIEqual, + assert(in_bit_sizes[0] == in_bit_sizes[1]); + /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */ + SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpIEqual; + result = emit_binop(ctx, op, get_bvec_type(ctx, alu_instr_src_components(alu, 0)), src[0], src[1]); result = emit_unop(ctx, SpvOpAll, dest_type, result); break; + } case nir_op_vec2: case nir_op_vec3: