diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index 6e63e153fb8..a3d9416bb19 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -133,6 +133,70 @@ typedef struct {
    nir_ssa_def *repacked_invocation_index;
 } wg_repack_result;
 
+/**
+ * Computes a horizontal sum of 8-bit packed values loaded from LDS.
+ *
+ * Each lane N will sum packed bytes 0 to N-1. We only care about the results from up to
+ * wave_id+1 lanes. (Other lanes are not deactivated but their calculation is not used.)
+ *
+ * packed_counts is 32-bit when num_lds_dwords is 1 and 64-bit when it is 2;
+ * other dword counts are not implemented. The returned sum is 32-bit.
+ */
+static nir_ssa_def *
+summarize_repack(nir_builder *b, nir_ssa_def *packed_counts, unsigned num_lds_dwords)
+{
+   /* We'll use shift to filter out the bytes not needed by the current lane.
+    *
+    * Need to shift by: num_lds_dwords * 4 - lane_id (in bytes).
+    * However, two shifts are needed because one can't go all the way,
+    * so the shift amount is half that (and in bits).
+    *
+    * When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
+    * This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
+    * therefore v_dot can get rid of the unneeded values.
+    * This sequence is preferable because it better hides the latency of the LDS.
+    *
+    * If the v_dot instruction can't be used, we left-shift the packed bytes.
+    * This will shift out the unneeded bytes and shift in zeroes instead,
+    * then we sum them using v_sad_u8.
+    */
+   nir_ssa_def *lane_id = nir_load_subgroup_invocation(b);
+   /* lane_id * -4u wraps for lane_id > 0 and the add wraps it back, so it must not be nuw. */
+   nir_ssa_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -4u), num_lds_dwords * 16);
+   bool use_dot = b->shader->options->has_dot_4x8;
+
+   if (num_lds_dwords == 1) {
+      nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int(b, 0x01010101), shift), shift);
+      /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
+      nir_ssa_def *packed = nir_build_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
+
+      /* Horizontally add the packed bytes. */
+      if (use_dot) {
+         return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
+      } else {
+         nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, packed, shift), shift);
+         return nir_sad_u8x4(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
+      }
+   } else if (num_lds_dwords == 2) {
+      nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift), shift);
+      /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
+      nir_ssa_def *packed_dw0 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
+      nir_ssa_def *packed_dw1 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
+
+      /* Horizontally add the packed bytes. */
+      if (use_dot) {
+         nir_ssa_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
+         return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
+      } else {
+         nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift), shift);
+         nir_ssa_def *sum = nir_sad_u8x4(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
+         return nir_sad_u8x4(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
+      }
+   } else {
+      unreachable("Unimplemented NGG wave count");
+   }
+}
+
 /**
  * Repacks invocations in the current workgroup to eliminate gaps between them.
* @@ -208,41 +272,7 @@ repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool, */ nir_ssa_def *num_waves = nir_build_load_num_subgroups(b); - - /* sel = 0x01010101 * lane_id + 0x03020100 */ - nir_ssa_def *lane_id = nir_load_subgroup_invocation(b); - nir_ssa_def *packed_id = nir_build_byte_permute_amd(b, nir_imm_int(b, 0), lane_id, nir_imm_int(b, 0)); - nir_ssa_def *sel = nir_iadd_imm_nuw(b, packed_id, 0x03020100); - nir_ssa_def *sum = NULL; - - if (num_lds_dwords == 1) { - /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */ - nir_ssa_def *packed_dw = nir_build_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0)); - - /* Use byte-permute to filter out the bytes not needed by the current lane. */ - nir_ssa_def *filtered_packed = nir_build_byte_permute_amd(b, packed_dw, nir_imm_int(b, 0), sel); - - /* Horizontally add the packed bytes. */ - sum = nir_sad_u8x4(b, filtered_packed, nir_imm_int(b, 0), nir_imm_int(b, 0)); - } else if (num_lds_dwords == 2) { - /* Create selectors for the byte-permutes below. */ - nir_ssa_def *dw0_selector = nir_build_lane_permute_16_amd(b, sel, nir_imm_int(b, 0x44443210), nir_imm_int(b, 0x4)); - nir_ssa_def *dw1_selector = nir_build_lane_permute_16_amd(b, sel, nir_imm_int(b, 0x32100000), nir_imm_int(b, 0x4)); - - /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */ - nir_ssa_def *packed_dw0 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0)); - nir_ssa_def *packed_dw1 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0)); - - /* Use byte-permute to filter out the bytes not needed by the current lane. 
*/ - nir_ssa_def *filtered_packed_dw0 = nir_build_byte_permute_amd(b, packed_dw0, nir_imm_int(b, 0), dw0_selector); - nir_ssa_def *filtered_packed_dw1 = nir_build_byte_permute_amd(b, packed_dw1, nir_imm_int(b, 0), dw1_selector); - - /* Horizontally add the packed bytes. */ - sum = nir_sad_u8x4(b, filtered_packed_dw0, nir_imm_int(b, 0), nir_imm_int(b, 0)); - sum = nir_sad_u8x4(b, filtered_packed_dw1, nir_imm_int(b, 0), sum); - } else { - unreachable("Unimplemented NGG wave count"); - } + nir_ssa_def *sum = summarize_repack(b, packed_counts, num_lds_dwords); nir_ssa_def *wg_repacked_index_base = nir_build_read_invocation(b, sum, wave_id); nir_ssa_def *wg_num_repacked_invocations = nir_build_read_invocation(b, sum, num_waves);