diff --git a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c
index e71c76f87ad..914b1f04bdf 100644
--- a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c
+++ b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c
@@ -257,10 +257,64 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin,
                                               ptr_comp_width * ptr_num_comps),
                                  glsl_base_type_get_bit_size(desc->element_type));
 
-   if ((nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) ==
-       (desc->use != GLSL_CMAT_USE_B)) {
-      stride = nir_udiv_imm(b, stride, packing_factor);
+   /* The data that will be packed is in successive columns for A and
+    * accumulator matrices. The data that will be packed for B matrices is in
+    * successive rows.
+    */
+   const unsigned cols =
+      desc->use != GLSL_CMAT_USE_B ? desc->cols / packing_factor : desc->cols;
+   nir_def *invocation = nir_load_subgroup_invocation(b);
+   nir_def *invocation_div_cols = nir_udiv_imm(b, invocation, cols);
+   nir_def *invocation_mod_cols = nir_umod_imm(b, invocation, cols);
+
+   nir_def *i_stride;
+
+   const bool memory_layout_matches_register_layout =
+      (nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) ==
+      (desc->use != GLSL_CMAT_USE_B);
+
+   if (memory_layout_matches_register_layout) {
+      /* In the row-major arrangement, data is loaded a dword at a time
+       * instead of a single element at a time. For this reason the stride is
+       * divided by the packing factor.
+       */
+      i_stride = nir_udiv_imm(b, stride, packing_factor);
+   } else {
+      /* In the column-major arrangement, data is loaded a single element at a
+       * time. Because the data elements are transposed, the step direction
+       * that moves a single (packed) element in the row-major arrangement has
+       * to explicitly step over the packing factor count of elements. For
+       * this reason the stride is multiplied by the packing factor.
+       *
+       * NOTE: The unscaled stride is also still needed when stepping from one
+       * packed element to the next. This occurs in the for-j loop below.
+       */
+      i_stride = nir_imul_imm(b, stride, packing_factor);
+   }
+
+   nir_def *base_offset;
+   nir_def *i_step;
+
+   if (nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
+      base_offset = nir_iadd(b,
+                             nir_imul(b,
+                                      invocation_div_cols,
+                                      i_stride),
+                             invocation_mod_cols);
+
+      i_step = nir_imul_imm(b, i_stride, state->subgroup_size / cols);
+   } else {
+      base_offset = nir_iadd(b,
+                             nir_imul(b,
+                                      invocation_mod_cols,
+                                      i_stride),
+                             invocation_div_cols);
+
+      i_step = nir_imm_int(b, state->subgroup_size / cols);
+   }
+
+   if (memory_layout_matches_register_layout) {
       const struct glsl_type *element_type =
          glsl_scalar_type(glsl_get_base_type(slice->type));
@@ -268,30 +322,8 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin,
                                      element_type,
                                      glsl_get_bit_size(element_type) / 8);
 
-      nir_def *invocation = nir_load_subgroup_invocation(b);
-      nir_def *base_offset;
-      nir_def *step;
-
-      if (desc->use != GLSL_CMAT_USE_B) {
-         base_offset = nir_iadd(b,
-                                nir_imul(b,
-                                         nir_udiv_imm(b, invocation, 8),
-                                         stride),
-                                nir_umod_imm(b, invocation, 8));
-
-         step = nir_imul_imm(b, stride, state->subgroup_size / 8);
-      } else {
-         base_offset = nir_iadd(b,
-                                nir_imul(b,
-                                         nir_umod_imm(b, invocation, 8),
-                                         stride),
-                                nir_udiv_imm(b, invocation, 8));
-
-         step = nir_imm_int(b, state->subgroup_size / 8);
-      }
-
       for (unsigned i = 0; i < num_components; i++) {
-         nir_def *offset = nir_imul_imm(b, step, i);
+         nir_def *offset = nir_imul_imm(b, i_step, i);
 
          nir_deref_instr *memory_deref =
             nir_build_deref_ptr_as_array(b, pointer,
@@ -316,45 +348,19 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin,
       pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes,
                                      element_type, element_stride);
 
-      nir_def *invocation_div_8 = nir_udiv_imm(b, nir_load_subgroup_invocation(b), 8);
-      nir_def *invocation_mod_8 = nir_umod_imm(b, nir_load_subgroup_invocation(b), 8);
-
-      nir_def *packed_stride = nir_imul_imm(b, stride, packing_factor);
-
       for (unsigned i = 0; i < num_components; i++) {
-         const unsigned i_offset = i * (state->subgroup_size / 8);
+         nir_def *i_offset = nir_imul_imm(b, i_step, i);
          nir_def *v[4];
 
         for (unsigned j = 0; j < packing_factor; j++) {
-            nir_def *j_offset = nir_imul_imm(b, stride, j);
-            nir_def *offset;
-
-            if (desc->use != GLSL_CMAT_USE_B) {
-               offset = nir_iadd(b,
-                                 nir_iadd(b,
-                                          nir_imul(b,
-                                                   invocation_mod_8,
-                                                   packed_stride),
-                                          invocation_div_8),
-                                 nir_iadd_imm(b, j_offset, i_offset));
-            } else {
-               offset = nir_iadd(b,
-                                 nir_iadd(b,
-                                          nir_imul(b,
-                                                   invocation_div_8,
-                                                   packed_stride),
-                                          invocation_mod_8),
-                                 nir_iadd(b,
-                                          nir_imul_imm(b,
-                                                       packed_stride,
-                                                       i_offset),
-                                          j_offset));
-            }
+            nir_def *offset = nir_iadd(b, nir_imul_imm(b, stride, j), i_offset);
 
            nir_deref_instr *memory_deref =
              nir_build_deref_ptr_as_array(b, pointer,
                                           nir_i2iN(b,
-                                                   offset,
+                                                   nir_iadd(b,
+                                                            base_offset,
+                                                            offset),
                                                    pointer->def.bit_size));
 
            if (load) {
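
As a reading aid for the addressing scheme in the patch, the following is a minimal standalone sketch in plain C of the offsets the lowered code computes per subgroup invocation. It is not driver code: the real pass builds this arithmetic as NIR SSA values, and the parameter values used here (subgroup size 16, a 16x16 B matrix of 16-bit elements giving a packing factor of 2, and a row stride of 16 elements) are illustrative assumptions only.

/* Standalone scalar model of the offset arithmetic built by the pass above.
 * All parameter values below are illustrative assumptions, not values taken
 * from the driver.
 */
#include <stdio.h>

int main(void)
{
   const unsigned subgroup_size = 16;
   const unsigned desc_cols = 16;      /* desc->cols                           */
   const unsigned packing_factor = 2;  /* 16-bit elements packed into dwords   */
   const unsigned num_components = 8;  /* packed elements owned per invocation */
   const unsigned stride = 16;         /* row stride, in elements              */
   const int is_use_b = 1;             /* desc->use == GLSL_CMAT_USE_B         */
   const int row_major = 0;            /* nir_intrinsic_matrix_layout()        */

   /* Packing runs down columns for A/accumulator and along rows for B. */
   const unsigned cols = !is_use_b ? desc_cols / packing_factor : desc_cols;

   const int layouts_match = (row_major == !is_use_b);

   /* Matching layouts step a packed dword at a time; mismatched layouts step
    * one element at a time and must hop over packing_factor elements.
    */
   const unsigned i_stride = layouts_match ? stride / packing_factor
                                           : stride * packing_factor;

   for (unsigned inv = 0; inv < subgroup_size; inv++) {
      unsigned base_offset, i_step;

      if (row_major) {
         base_offset = (inv / cols) * i_stride + inv % cols;
         i_step = i_stride * (subgroup_size / cols);
      } else {
         base_offset = (inv % cols) * i_stride + inv / cols;
         i_step = subgroup_size / cols;
      }

      for (unsigned i = 0; i < num_components; i++) {
         if (layouts_match) {
            /* One packed (dword) access per component. */
            printf("inv %2u comp %u -> %u\n", inv, i, base_offset + i_step * i);
         } else {
            /* One element per access; j walks within the packed dword using
             * the unscaled stride, mirroring the for-j loop in the pass.
             */
            for (unsigned j = 0; j < packing_factor; j++)
               printf("inv %2u comp %u.%u -> %u\n",
                      inv, i, j, base_offset + i_step * i + stride * j);
         }
      }
   }

   return 0;
}

Under these assumptions the same base_offset/i_step pair drives both paths, which is the factoring the patch introduces: only the innermost j loop, which uses the unscaled stride, differs between the packed and element-at-a-time cases.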