diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c index 5e6bbe8f953..a3e26586bed 100644 --- a/src/amd/common/ac_nir_lower_ngg.c +++ b/src/amd/common/ac_nir_lower_ngg.c @@ -120,7 +120,6 @@ typedef struct nir_ssa_def *lds_addr_gs_scratch; unsigned lds_bytes_per_gs_out_vertex; unsigned lds_offs_primflags; - bool found_out_vtxcnt[4]; bool output_compile_time_known; bool streamout_enabled; /* 32 bit outputs */ @@ -131,6 +130,9 @@ typedef struct nir_variable *output_vars_16bit_lo[16][4]; gs_output_info output_info_16bit_hi[16]; gs_output_info output_info_16bit_lo[16]; + /* Count per stream. */ + nir_ssa_def *vertex_count[4]; + nir_ssa_def *primitive_count[4]; } lower_ngg_gs_state; /* LDS layout of Mesh Shader workgroup info. */ @@ -2390,7 +2392,7 @@ ngg_gs_clear_primflags(nir_builder *b, nir_ssa_def *num_vertices, unsigned strea } static void -ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s) +ngg_gs_shader_query(nir_builder *b, lower_ngg_gs_state *s) { bool has_gen_prim_query = s->options->has_gen_prim_query; bool has_pipeline_stats_query = s->options->gfx_level < GFX11; @@ -2415,25 +2417,36 @@ ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_st } nir_if *if_shader_query = nir_push_if(b, shader_query_enabled); - nir_ssa_def *num_prims_in_wave = NULL; + + nir_ssa_def *active_threads_mask = nir_ballot(b, 1, s->options->wave_size, nir_imm_bool(b, true)); + nir_ssa_def *num_active_threads = nir_bit_count(b, active_threads_mask); /* Calculate the "real" number of emitted primitives from the emitted GS vertices and primitives. * GS emits points, line strips or triangle strips. * Real primitives are points, lines or triangles. */ - if (nir_src_is_const(intrin->src[0]) && nir_src_is_const(intrin->src[1])) { - unsigned gs_vtx_cnt = nir_src_as_uint(intrin->src[0]); - unsigned gs_prm_cnt = nir_src_as_uint(intrin->src[1]); - unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u); - nir_ssa_def *num_threads = - nir_bit_count(b, nir_ballot(b, 1, s->options->wave_size, nir_imm_bool(b, true))); - num_prims_in_wave = nir_imul_imm(b, num_threads, total_prm_cnt); - } else { - nir_ssa_def *gs_vtx_cnt = intrin->src[0].ssa; - nir_ssa_def *prm_cnt = intrin->src[1].ssa; - if (s->num_vertices_per_primitive > 1) - prm_cnt = nir_iadd(b, nir_imul_imm(b, prm_cnt, -1u * (s->num_vertices_per_primitive - 1)), gs_vtx_cnt); - num_prims_in_wave = nir_reduce(b, prm_cnt, .reduction_op = nir_op_iadd); + nir_ssa_def *num_prims_in_wave[4] = {0}; + u_foreach_bit (i, b->shader->info.gs.active_stream_mask) { + assert(s->vertex_count[i] && s->primitive_count[i]); + + nir_ssa_scalar vtx_cnt = nir_get_ssa_scalar(s->vertex_count[i], 0); + nir_ssa_scalar prm_cnt = nir_get_ssa_scalar(s->primitive_count[i], 0); + + if (nir_ssa_scalar_is_const(vtx_cnt) && nir_ssa_scalar_is_const(prm_cnt)) { + unsigned gs_vtx_cnt = nir_ssa_scalar_as_uint(vtx_cnt); + unsigned gs_prm_cnt = nir_ssa_scalar_as_uint(prm_cnt); + unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u); + if (total_prm_cnt == 0) + continue; + + num_prims_in_wave[i] = nir_imul_imm(b, num_active_threads, total_prm_cnt); + } else { + nir_ssa_def *gs_vtx_cnt = vtx_cnt.def; + nir_ssa_def *gs_prm_cnt = prm_cnt.def; + if (s->num_vertices_per_primitive > 1) + gs_prm_cnt = nir_iadd(b, nir_imul_imm(b, gs_prm_cnt, -1u * (s->num_vertices_per_primitive - 1)), gs_vtx_cnt); + num_prims_in_wave[i] = nir_reduce(b, gs_prm_cnt, .reduction_op = nir_op_iadd); + } } /* Store the query result to query result using an atomic add. */ @@ -2442,8 +2455,20 @@ ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_st if (has_pipeline_stats_query) { nir_if *if_pipeline_query = nir_push_if(b, pipeline_query_enabled); { + nir_ssa_def *count = NULL; + /* Add all streams' number to the same counter. */ - nir_atomic_add_gs_emit_prim_count_amd(b, num_prims_in_wave); + for (int i = 0; i < 4; i++) { + if (num_prims_in_wave[i]) { + if (count) + count = nir_iadd(b, count, num_prims_in_wave[i]); + else + count = num_prims_in_wave[i]; + } + } + + if (count) + nir_atomic_add_gs_emit_prim_count_amd(b, count); } nir_pop_if(b, if_pipeline_query); } @@ -2452,8 +2477,10 @@ ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_st nir_if *if_prim_gen_query = nir_push_if(b, prim_gen_query_enabled); { /* Add to the counter for this stream. */ - nir_atomic_add_gen_prim_count_amd( - b, num_prims_in_wave, .stream_id = nir_intrinsic_stream_id(intrin)); + for (int i = 0; i < 4; i++) { + if (num_prims_in_wave[i]) + nir_atomic_add_gen_prim_count_amd(b, num_prims_in_wave[i], .stream_id = i); + } } nir_pop_if(b, if_prim_gen_query); } @@ -2708,13 +2735,13 @@ lower_ngg_gs_set_vertex_and_primitive_count(nir_builder *b, nir_intrinsic_instr return true; } - s->found_out_vtxcnt[stream] = true; + s->vertex_count[stream] = intrin->src[0].ssa; + s->primitive_count[stream] = intrin->src[1].ssa; /* Clear the primitive flags of non-emitted vertices */ if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < b->shader->info.gs.vertices_out) ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s); - ngg_gs_shader_query(b, intrin, s); nir_instr_remove(&intrin->instr); return true; } @@ -3344,13 +3371,18 @@ ac_nir_lower_ngg_gs(nir_shader *shader, const ac_nir_lower_ngg_options *options) /* Lower the GS intrinsics */ lower_ngg_gs_intrinsics(shader, &state); - b->cursor = nir_after_cf_list(&impl->body); - if (!state.found_out_vtxcnt[0]) { + if (!state.vertex_count[0]) { fprintf(stderr, "Could not find set_vertex_and_primitive_count for stream 0. This would hang your GPU."); abort(); } + /* Emit shader queries */ + b->cursor = nir_after_cf_list(&if_gs_thread->then_list); + ngg_gs_shader_query(b, &state); + + b->cursor = nir_after_cf_list(&impl->body); + /* Emit the finale sequence */ ngg_gs_finale(b, &state); nir_validate_shader(shader, "after emitting NGG GS");