ac/nir/ngg: merge multi stream gs shader queries

Before this commit each stream will emit a query block, now
we merge them to a single block.

Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Acked-by: Marek Olšák <marek.olsak@amd.com>
Signed-off-by: Qiang Yu <yuq825@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20074>
This commit is contained in:
Qiang Yu
2022-11-30 15:22:29 +08:00
committed by Marge Bot
parent b7b91ae51e
commit 2fb1097bac

View File

@@ -120,7 +120,6 @@ typedef struct
nir_ssa_def *lds_addr_gs_scratch;
unsigned lds_bytes_per_gs_out_vertex;
unsigned lds_offs_primflags;
bool found_out_vtxcnt[4];
bool output_compile_time_known;
bool streamout_enabled;
/* 32 bit outputs */
@@ -131,6 +130,9 @@ typedef struct
nir_variable *output_vars_16bit_lo[16][4];
gs_output_info output_info_16bit_hi[16];
gs_output_info output_info_16bit_lo[16];
/* Count per stream. */
nir_ssa_def *vertex_count[4];
nir_ssa_def *primitive_count[4];
} lower_ngg_gs_state;
/* LDS layout of Mesh Shader workgroup info. */
@@ -2390,7 +2392,7 @@ ngg_gs_clear_primflags(nir_builder *b, nir_ssa_def *num_vertices, unsigned strea
}
static void
ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
ngg_gs_shader_query(nir_builder *b, lower_ngg_gs_state *s)
{
bool has_gen_prim_query = s->options->has_gen_prim_query;
bool has_pipeline_stats_query = s->options->gfx_level < GFX11;
@@ -2415,25 +2417,36 @@ ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_st
}
nir_if *if_shader_query = nir_push_if(b, shader_query_enabled);
nir_ssa_def *num_prims_in_wave = NULL;
nir_ssa_def *active_threads_mask = nir_ballot(b, 1, s->options->wave_size, nir_imm_bool(b, true));
nir_ssa_def *num_active_threads = nir_bit_count(b, active_threads_mask);
/* Calculate the "real" number of emitted primitives from the emitted GS vertices and primitives.
* GS emits points, line strips or triangle strips.
* Real primitives are points, lines or triangles.
*/
if (nir_src_is_const(intrin->src[0]) && nir_src_is_const(intrin->src[1])) {
unsigned gs_vtx_cnt = nir_src_as_uint(intrin->src[0]);
unsigned gs_prm_cnt = nir_src_as_uint(intrin->src[1]);
unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u);
nir_ssa_def *num_threads =
nir_bit_count(b, nir_ballot(b, 1, s->options->wave_size, nir_imm_bool(b, true)));
num_prims_in_wave = nir_imul_imm(b, num_threads, total_prm_cnt);
} else {
nir_ssa_def *gs_vtx_cnt = intrin->src[0].ssa;
nir_ssa_def *prm_cnt = intrin->src[1].ssa;
if (s->num_vertices_per_primitive > 1)
prm_cnt = nir_iadd(b, nir_imul_imm(b, prm_cnt, -1u * (s->num_vertices_per_primitive - 1)), gs_vtx_cnt);
num_prims_in_wave = nir_reduce(b, prm_cnt, .reduction_op = nir_op_iadd);
nir_ssa_def *num_prims_in_wave[4] = {0};
u_foreach_bit (i, b->shader->info.gs.active_stream_mask) {
assert(s->vertex_count[i] && s->primitive_count[i]);
nir_ssa_scalar vtx_cnt = nir_get_ssa_scalar(s->vertex_count[i], 0);
nir_ssa_scalar prm_cnt = nir_get_ssa_scalar(s->primitive_count[i], 0);
if (nir_ssa_scalar_is_const(vtx_cnt) && nir_ssa_scalar_is_const(prm_cnt)) {
unsigned gs_vtx_cnt = nir_ssa_scalar_as_uint(vtx_cnt);
unsigned gs_prm_cnt = nir_ssa_scalar_as_uint(prm_cnt);
unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u);
if (total_prm_cnt == 0)
continue;
num_prims_in_wave[i] = nir_imul_imm(b, num_active_threads, total_prm_cnt);
} else {
nir_ssa_def *gs_vtx_cnt = vtx_cnt.def;
nir_ssa_def *gs_prm_cnt = prm_cnt.def;
if (s->num_vertices_per_primitive > 1)
gs_prm_cnt = nir_iadd(b, nir_imul_imm(b, gs_prm_cnt, -1u * (s->num_vertices_per_primitive - 1)), gs_vtx_cnt);
num_prims_in_wave[i] = nir_reduce(b, gs_prm_cnt, .reduction_op = nir_op_iadd);
}
}
/* Store the query result to query result using an atomic add. */
@@ -2442,8 +2455,20 @@ ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_st
if (has_pipeline_stats_query) {
nir_if *if_pipeline_query = nir_push_if(b, pipeline_query_enabled);
{
nir_ssa_def *count = NULL;
/* Add all streams' number to the same counter. */
nir_atomic_add_gs_emit_prim_count_amd(b, num_prims_in_wave);
for (int i = 0; i < 4; i++) {
if (num_prims_in_wave[i]) {
if (count)
count = nir_iadd(b, count, num_prims_in_wave[i]);
else
count = num_prims_in_wave[i];
}
}
if (count)
nir_atomic_add_gs_emit_prim_count_amd(b, count);
}
nir_pop_if(b, if_pipeline_query);
}
@@ -2452,8 +2477,10 @@ ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_st
nir_if *if_prim_gen_query = nir_push_if(b, prim_gen_query_enabled);
{
/* Add to the counter for this stream. */
nir_atomic_add_gen_prim_count_amd(
b, num_prims_in_wave, .stream_id = nir_intrinsic_stream_id(intrin));
for (int i = 0; i < 4; i++) {
if (num_prims_in_wave[i])
nir_atomic_add_gen_prim_count_amd(b, num_prims_in_wave[i], .stream_id = i);
}
}
nir_pop_if(b, if_prim_gen_query);
}
@@ -2708,13 +2735,13 @@ lower_ngg_gs_set_vertex_and_primitive_count(nir_builder *b, nir_intrinsic_instr
return true;
}
s->found_out_vtxcnt[stream] = true;
s->vertex_count[stream] = intrin->src[0].ssa;
s->primitive_count[stream] = intrin->src[1].ssa;
/* Clear the primitive flags of non-emitted vertices */
if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < b->shader->info.gs.vertices_out)
ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s);
ngg_gs_shader_query(b, intrin, s);
nir_instr_remove(&intrin->instr);
return true;
}
@@ -3344,13 +3371,18 @@ ac_nir_lower_ngg_gs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
/* Lower the GS intrinsics */
lower_ngg_gs_intrinsics(shader, &state);
b->cursor = nir_after_cf_list(&impl->body);
if (!state.found_out_vtxcnt[0]) {
if (!state.vertex_count[0]) {
fprintf(stderr, "Could not find set_vertex_and_primitive_count for stream 0. This would hang your GPU.");
abort();
}
/* Emit shader queries */
b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
ngg_gs_shader_query(b, &state);
b->cursor = nir_after_cf_list(&impl->body);
/* Emit the finale sequence */
ngg_gs_finale(b, &state);
nir_validate_shader(shader, "after emitting NGG GS");