intel/brw: Major rework of lower_cmat_load_store

The original goal was to get rid of a bunch of the magic constants
sprinkled throughout the function. Once I did that, I realized that a
lot more symmetry between the row-major and column-major paths was
possible.

It's a net +6 lines of code, but about 15 of the added lines are
comments explaining things that were not obvious in the original code.

v2: Store the duplicated condition in a variable with a meaningful
name. Suggested by Caio.

Reviewed-by: Caio Oliveira <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/28834>
Author: Ian Romanick
Date:   2024-04-15 10:40:51 -07:00
parent ea6e10c0b2
commit 7a773ac53e
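As background for the diff below, here is a standalone C sketch (not Mesa
code) of the div/mod symmetry the rework exploits: the column-major index
math is the row-major math with the division and modulus swapped. The
subgroup size, column count, and stride are hypothetical example values,
and plain unsigned integers stand in for the nir_def SSA values.

#include <stdio.h>

int main(void)
{
   const unsigned subgroup_size = 16;  /* assumed SIMD16 dispatch */
   const unsigned cols = 8;            /* example packed-column count */
   const unsigned stride = 32;         /* example row stride */

   for (unsigned invocation = 0; invocation < subgroup_size; invocation++) {
      /* Row-major: invocation / cols selects the row, invocation % cols
       * the position within it. */
      unsigned row_major = (invocation / cols) * stride + invocation % cols;

      /* Column-major: the same expression with div and mod swapped. */
      unsigned col_major = (invocation % cols) * stride + invocation / cols;

      printf("invocation %2u: row-major %3u, column-major %3u\n",
             invocation, row_major, col_major);
   }

   return 0;
}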


@@ -257,10 +257,64 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin,
                                            ptr_comp_width * ptr_num_comps),
                               glsl_base_type_get_bit_size(desc->element_type));
 
-   if ((nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) ==
-       (desc->use != GLSL_CMAT_USE_B)) {
-      stride = nir_udiv_imm(b, stride, packing_factor);
+   /* The data that will be packed is in successive columns for A and
+    * accumulator matrices. The data that will be packed for B matrices is in
+    * successive rows.
+    */
+   const unsigned cols =
+      desc->use != GLSL_CMAT_USE_B ? desc->cols / packing_factor : desc->cols;
+
+   nir_def *invocation = nir_load_subgroup_invocation(b);
+   nir_def *invocation_div_cols = nir_udiv_imm(b, invocation, cols);
+   nir_def *invocation_mod_cols = nir_umod_imm(b, invocation, cols);
+
+   nir_def *i_stride;
+
+   const bool memory_layout_matches_register_layout =
+      (nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) ==
+      (desc->use != GLSL_CMAT_USE_B);
+
+   if (memory_layout_matches_register_layout) {
+      /* In the row-major arrangement, data is loaded a dword at a time
+       * instead of a single element at a time. For this reason the stride is
+       * divided by the packing factor.
+       */
+      i_stride = nir_udiv_imm(b, stride, packing_factor);
+   } else {
+      /* In the column-major arrangement, data is loaded a single element at a
+       * time. Because the data elements are transposed, the step direction
+       * that moves a single (packed) element in the row-major arrangement has
+       * to explicitly step over the packing factor count of elements. For
+       * this reason the stride is multiplied by the packing factor.
+       *
+       * NOTE: The unscaled stride is also still needed when stepping from one
+       * packed element to the next. This occurs in the for-j loop below.
+       */
+      i_stride = nir_imul_imm(b, stride, packing_factor);
+   }
+
+   nir_def *base_offset;
+   nir_def *i_step;
+
+   if (nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
+      base_offset = nir_iadd(b,
+                             nir_imul(b,
+                                      invocation_div_cols,
+                                      i_stride),
+                             invocation_mod_cols);
+
+      i_step = nir_imul_imm(b, i_stride, state->subgroup_size / cols);
+   } else {
+      base_offset = nir_iadd(b,
+                             nir_imul(b,
+                                      invocation_mod_cols,
+                                      i_stride),
+                             invocation_div_cols);
+
+      i_step = nir_imm_int(b, state->subgroup_size / cols);
+   }
 
+   if (memory_layout_matches_register_layout) {
       const struct glsl_type *element_type =
          glsl_scalar_type(glsl_get_base_type(slice->type));
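A scalar sketch of the i_stride scaling introduced in the hunk above; the
numbers in the comments are hypothetical, and in the real pass these values
are built as nir_def SSA expressions rather than plain integers.

#include <stdbool.h>

/* Stand-in for the i_stride computation: e.g. stride 32 with
 * packing_factor 2 yields 16 when the layouts match and 64 when the
 * matrix has to be transposed on load. */
static unsigned
scale_i_stride(unsigned stride, unsigned packing_factor,
               bool memory_layout_matches_register_layout)
{
   return memory_layout_matches_register_layout ?
          stride / packing_factor :   /* step in whole packed dwords */
          stride * packing_factor;    /* hop over packing_factor rows */
}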
@@ -268,30 +322,8 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin,
                                      element_type,
                                      glsl_get_bit_size(element_type) / 8);
 
-      nir_def *invocation = nir_load_subgroup_invocation(b);
-
-      nir_def *base_offset;
-      nir_def *step;
-
-      if (desc->use != GLSL_CMAT_USE_B) {
-         base_offset = nir_iadd(b,
-                                nir_imul(b,
-                                         nir_udiv_imm(b, invocation, 8),
-                                         stride),
-                                nir_umod_imm(b, invocation, 8));
-
-         step = nir_imul_imm(b, stride, state->subgroup_size / 8);
-      } else {
-         base_offset = nir_iadd(b,
-                                nir_imul(b,
-                                         nir_umod_imm(b, invocation, 8),
-                                         stride),
-                                nir_udiv_imm(b, invocation, 8));
-
-         step = nir_imm_int(b, state->subgroup_size / 8);
-      }
-
       for (unsigned i = 0; i < num_components; i++) {
-         nir_def *offset = nir_imul_imm(b, step, i);
+         nir_def *offset = nir_imul_imm(b, i_step, i);
 
          nir_deref_instr *memory_deref =
             nir_build_deref_ptr_as_array(b, pointer,
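For the matched-layout path in the hunk above, each invocation simply reads
num_components packed values that sit i_step apart. A minimal scalar sketch,
assuming a flat memory array and that the deref adds base_offset as in the
column-major path; all parameters are hypothetical stand-ins for the pass's
SSA values.

static void
load_matched_layout(const unsigned *memory, unsigned *dest,
                    unsigned num_components,
                    unsigned base_offset, unsigned i_step)
{
   for (unsigned i = 0; i < num_components; i++)
      dest[i] = memory[base_offset + i_step * i];   /* one packed dword */
}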
@@ -316,45 +348,19 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin,
       pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes, element_type,
                                      element_stride);
 
-      nir_def *invocation_div_8 = nir_udiv_imm(b, nir_load_subgroup_invocation(b), 8);
-      nir_def *invocation_mod_8 = nir_umod_imm(b, nir_load_subgroup_invocation(b), 8);
-      nir_def *packed_stride = nir_imul_imm(b, stride, packing_factor);
-
       for (unsigned i = 0; i < num_components; i++) {
-         const unsigned i_offset = i * (state->subgroup_size / 8);
+         nir_def *i_offset = nir_imul_imm(b, i_step, i);
          nir_def *v[4];
 
          for (unsigned j = 0; j < packing_factor; j++) {
-            nir_def *j_offset = nir_imul_imm(b, stride, j);
-            nir_def *offset;
-
-            if (desc->use != GLSL_CMAT_USE_B) {
-               offset = nir_iadd(b,
-                                 nir_iadd(b,
-                                          nir_imul(b,
-                                                   invocation_mod_8,
-                                                   packed_stride),
-                                          invocation_div_8),
-                                 nir_iadd_imm(b, j_offset, i_offset));
-            } else {
-               offset = nir_iadd(b,
-                                 nir_iadd(b,
-                                          nir_imul(b,
-                                                   invocation_div_8,
-                                                   packed_stride),
-                                          invocation_mod_8),
-                                 nir_iadd(b,
-                                          nir_imul_imm(b,
-                                                       packed_stride,
-                                                       i_offset),
-                                          j_offset));
-            }
+            nir_def *offset = nir_iadd(b, nir_imul_imm(b, stride, j), i_offset);
 
             nir_deref_instr *memory_deref =
                nir_build_deref_ptr_as_array(b, pointer,
                                             nir_i2iN(b,
-                                                     offset,
+                                                     nir_iadd(b,
+                                                              base_offset,
+                                                              offset),
                                                      pointer->def.bit_size));
 
             if (load) {
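To summarize the transposed path in the final hunk: each register component
is gathered from packing_factor elements that sit one unscaled stride apart,
while the packing_factor scaling for the i direction is already folded into
i_step. A hedged scalar sketch, with every name standing in for one of the
pass's SSA values and a flat memory array assumed.

static void
load_transposed(const unsigned *memory, unsigned *dest,
                unsigned num_components, unsigned packing_factor,
                unsigned base_offset, unsigned stride, unsigned i_step)
{
   for (unsigned i = 0; i < num_components; i++) {
      unsigned v[4];   /* mirrors nir_def *v[4]; assumes packing_factor <= 4 */

      for (unsigned j = 0; j < packing_factor; j++) {
         /* j steps use the unscaled stride (see the NOTE in the first
          * hunk); i_step already carries the packing_factor scaling. */
         v[j] = memory[base_offset + stride * j + i_step * i];
      }

      /* The real pass packs v[0..packing_factor-1] into one packed
       * register component; this sketch just keeps the first element. */
      dest[i] = v[0];
   }
}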