ac/nir: Remove byte permute from prefix sum of the repack sequence.
The byte-permute instruction v_perm_b32 is not exposed by older LLVM releases (only available on LLVM 13 and later), therefore a new sequence is needed which we can use with these LLVM versions too. The prefix sum is replaced by two alternatives: 1. For GPUs that support v_dot, we shift 0x01 to the wanted byte positions and then use v_dot to sum the results. 2. For older GPUs (Navi 10), we simply shift out the unwanted bytes and use v_sad_u8 to produce the sum. Signed-off-by: Timur Kristóf <timur.kristof@gmail.com> Acked-by: Marek Olšák <marek.olsak@amd.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12786>
This commit is contained in:
@@ -133,6 +133,70 @@ typedef struct {
|
||||
nir_ssa_def *repacked_invocation_index;
|
||||
} wg_repack_result;
|
||||
|
||||
/**
|
||||
* Computes a horizontal sum of 8-bit packed values loaded from LDS.
|
||||
*
|
||||
* Each lane N will sum packed bytes 0 to N-1.
|
||||
* We only care about the results from up to wave_id+1 lanes.
|
||||
* (Other lanes are not deactivated but their calculation is not used.)
|
||||
*/
|
||||
static nir_ssa_def *
|
||||
summarize_repack(nir_builder *b, nir_ssa_def *packed_counts, unsigned num_lds_dwords)
|
||||
{
|
||||
/* We'll use shift to filter out the bytes not needed by the current lane.
|
||||
*
|
||||
* Need to shift by: num_lds_dwords * 4 - lane_id (in bytes).
|
||||
* However, two shifts are needed because one can't go all the way,
|
||||
* so the shift amount is half that (and in bits).
|
||||
*
|
||||
* When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
|
||||
* This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
|
||||
* therefore v_dot can get rid of the unneeded values.
|
||||
* This sequence is preferable because it better hides the latency of the LDS.
|
||||
*
|
||||
* If the v_dot instruction can't be used, we left-shift the packed bytes.
|
||||
* This will shift out the unneeded bytes and shift in zeroes instead,
|
||||
* then we sum them using v_sad_u8.
|
||||
*/
|
||||
|
||||
nir_ssa_def *lane_id = nir_load_subgroup_invocation(b);
|
||||
nir_ssa_def *shift = nir_iadd_imm_nuw(b, nir_imul_imm(b, lane_id, -4u), num_lds_dwords * 16);
|
||||
bool use_dot = b->shader->options->has_dot_4x8;
|
||||
|
||||
if (num_lds_dwords == 1) {
|
||||
nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int(b, 0x01010101), shift), shift);
|
||||
|
||||
/* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
|
||||
nir_ssa_def *packed = nir_build_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
|
||||
/* Horizontally add the packed bytes. */
|
||||
if (use_dot) {
|
||||
return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
|
||||
} else {
|
||||
nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, packed, shift), shift);
|
||||
return nir_sad_u8x4(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
}
|
||||
} else if (num_lds_dwords == 2) {
|
||||
nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift), shift);
|
||||
|
||||
/* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
|
||||
nir_ssa_def *packed_dw0 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
nir_ssa_def *packed_dw1 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
|
||||
/* Horizontally add the packed bytes. */
|
||||
if (use_dot) {
|
||||
nir_ssa_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
|
||||
return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
|
||||
} else {
|
||||
nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift), shift);
|
||||
nir_ssa_def *sum = nir_sad_u8x4(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
return nir_sad_u8x4(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
|
||||
}
|
||||
} else {
|
||||
unreachable("Unimplemented NGG wave count");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Repacks invocations in the current workgroup to eliminate gaps between them.
|
||||
*
|
||||
@@ -208,41 +272,7 @@ repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool,
|
||||
*/
|
||||
|
||||
nir_ssa_def *num_waves = nir_build_load_num_subgroups(b);
|
||||
|
||||
/* sel = 0x01010101 * lane_id + 0x03020100 */
|
||||
nir_ssa_def *lane_id = nir_load_subgroup_invocation(b);
|
||||
nir_ssa_def *packed_id = nir_build_byte_permute_amd(b, nir_imm_int(b, 0), lane_id, nir_imm_int(b, 0));
|
||||
nir_ssa_def *sel = nir_iadd_imm_nuw(b, packed_id, 0x03020100);
|
||||
nir_ssa_def *sum = NULL;
|
||||
|
||||
if (num_lds_dwords == 1) {
|
||||
/* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
|
||||
nir_ssa_def *packed_dw = nir_build_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
|
||||
/* Use byte-permute to filter out the bytes not needed by the current lane. */
|
||||
nir_ssa_def *filtered_packed = nir_build_byte_permute_amd(b, packed_dw, nir_imm_int(b, 0), sel);
|
||||
|
||||
/* Horizontally add the packed bytes. */
|
||||
sum = nir_sad_u8x4(b, filtered_packed, nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
} else if (num_lds_dwords == 2) {
|
||||
/* Create selectors for the byte-permutes below. */
|
||||
nir_ssa_def *dw0_selector = nir_build_lane_permute_16_amd(b, sel, nir_imm_int(b, 0x44443210), nir_imm_int(b, 0x4));
|
||||
nir_ssa_def *dw1_selector = nir_build_lane_permute_16_amd(b, sel, nir_imm_int(b, 0x32100000), nir_imm_int(b, 0x4));
|
||||
|
||||
/* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
|
||||
nir_ssa_def *packed_dw0 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
nir_ssa_def *packed_dw1 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
|
||||
/* Use byte-permute to filter out the bytes not needed by the current lane. */
|
||||
nir_ssa_def *filtered_packed_dw0 = nir_build_byte_permute_amd(b, packed_dw0, nir_imm_int(b, 0), dw0_selector);
|
||||
nir_ssa_def *filtered_packed_dw1 = nir_build_byte_permute_amd(b, packed_dw1, nir_imm_int(b, 0), dw1_selector);
|
||||
|
||||
/* Horizontally add the packed bytes. */
|
||||
sum = nir_sad_u8x4(b, filtered_packed_dw0, nir_imm_int(b, 0), nir_imm_int(b, 0));
|
||||
sum = nir_sad_u8x4(b, filtered_packed_dw1, nir_imm_int(b, 0), sum);
|
||||
} else {
|
||||
unreachable("Unimplemented NGG wave count");
|
||||
}
|
||||
nir_ssa_def *sum = summarize_repack(b, packed_counts, num_lds_dwords);
|
||||
|
||||
nir_ssa_def *wg_repacked_index_base = nir_build_read_invocation(b, sum, wave_id);
|
||||
nir_ssa_def *wg_num_repacked_invocations = nir_build_read_invocation(b, sum, num_waves);
|
||||
|
Reference in New Issue
Block a user