diff --git a/.pick_status.json b/.pick_status.json index 33b7de9321e..08e8966ac91 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -264,7 +264,7 @@ "description": "radv: Use correct writemask for cooperative matrix ordering.", "nominated": true, "nomination_type": 1, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": "9df4703fbb59d1295a9d3daf6320f329c9de2d66", "notes": null diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c index d81231b0137..e882100e141 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c +++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c @@ -181,7 +181,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) nir_def *elem = intr->src[1].ssa; nir_def *r = nir_vector_insert(&b, src1, elem, index); - nir_store_deref(&b, dst_deref, r, 0xffff); + nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components)); nir_instr_remove(instr); progress = true; break; @@ -193,7 +193,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, wave_size)); - nir_store_deref(&b, dst_deref, r, 0xffff); + nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components)); nir_instr_remove(instr); progress = true; break; @@ -253,7 +253,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) } nir_def *mat = nir_vec(&b, vars, length); - nir_store_deref(&b, dst_deref, mat, 0xffff); + nir_store_deref(&b, dst_deref, mat, nir_component_mask(mat->num_components)); nir_instr_remove(instr); progress = true; break; @@ -332,7 +332,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr), .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr)); - nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff); + nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, + nir_component_mask(ret->num_components)); nir_instr_remove(instr); progress = true; break; @@ -366,7 +367,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) ret = nir_vec(&b, components, ret->num_components * 2); } - nir_store_deref(&b, dst_deref, ret, 0xffff); + nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components)); nir_instr_remove(instr); progress = true; break; @@ -375,7 +376,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa); nir_op op = nir_intrinsic_alu_op(intr); nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa); - nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff); + nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, + nir_component_mask(ret->num_components)); nir_instr_remove(instr); progress = true; break; @@ -385,14 +387,16 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) nir_def *src2 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa); nir_op op = nir_intrinsic_alu_op(intr); nir_def *ret = nir_build_alu2(&b, op, src1, src2); - nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff); + nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, + nir_component_mask(ret->num_components)); nir_instr_remove(instr); progress = true; break; } case nir_intrinsic_cmat_bitcast: { nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa); - nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1, 0xffff); + nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1, + nir_component_mask(src1->num_components)); nir_instr_remove(instr); progress = true; break;