radv: Use correct writemask for cooperative matrix ordering.

This is not expected to fix anything externally visible, but it
avoids some invalid usage of a fixed 0xffff writemask when the
resulting vector is not 16 elements long (e.g. the C/result matrix).

Fixes: 9df4703fbb ("radv: Add cooperative matrix lowering.")
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26768>
(cherry picked from commit 07ad6fd34a)
This commit is contained in:
Bas Nieuwenhuizen
2023-12-20 00:19:55 +01:00
committed by Eric Engestrom
parent a23408d57a
commit 953da13070
2 changed files with 13 additions and 9 deletions

View File

@@ -264,7 +264,7 @@
"description": "radv: Use correct writemask for cooperative matrix ordering.",
"nominated": true,
"nomination_type": 1,
"resolution": 0,
"resolution": 1,
"main_sha": null,
"because_sha": "9df4703fbb59d1295a9d3daf6320f329c9de2d66",
"notes": null

View File

@@ -181,7 +181,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_def *elem = intr->src[1].ssa;
nir_def *r = nir_vector_insert(&b, src1, elem, index);
nir_store_deref(&b, dst_deref, r, 0xffff);
nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
nir_instr_remove(instr);
progress = true;
break;
@@ -193,7 +193,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, wave_size));
nir_store_deref(&b, dst_deref, r, 0xffff);
nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components));
nir_instr_remove(instr);
progress = true;
break;
@@ -253,7 +253,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
}
nir_def *mat = nir_vec(&b, vars, length);
nir_store_deref(&b, dst_deref, mat, 0xffff);
nir_store_deref(&b, dst_deref, mat, nir_component_mask(mat->num_components));
nir_instr_remove(instr);
progress = true;
break;
@@ -332,7 +332,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
.cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
nir_component_mask(ret->num_components));
nir_instr_remove(instr);
progress = true;
break;
@@ -366,7 +367,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
ret = nir_vec(&b, components, ret->num_components * 2);
}
nir_store_deref(&b, dst_deref, ret, 0xffff);
nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));
nir_instr_remove(instr);
progress = true;
break;
@@ -375,7 +376,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
nir_op op = nir_intrinsic_alu_op(intr);
nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
nir_component_mask(ret->num_components));
nir_instr_remove(instr);
progress = true;
break;
@@ -385,14 +387,16 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_def *src2 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
nir_op op = nir_intrinsic_alu_op(intr);
nir_def *ret = nir_build_alu2(&b, op, src1, src2);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
nir_component_mask(ret->num_components));
nir_instr_remove(instr);
progress = true;
break;
}
case nir_intrinsic_cmat_bitcast: {
nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1, 0xffff);
nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1,
nir_component_mask(src1->num_components));
nir_instr_remove(instr);
progress = true;
break;