nir: Add ability to count emitted GS vertices per primitive.

Add an option to nir_lower_gs_intrinsics so that it can also track
the number of emitted vertices per primitive, not just the total
vertex count.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6964>
This commit is contained in:
Timur Kristóf
2020-06-16 18:58:39 +02:00
parent 2be99012e9
commit c977c369d3
3 changed files with 53 additions and 6 deletions

View File

@@ -4729,6 +4729,7 @@ bool nir_lower_to_source_mods(nir_shader *shader, nir_lower_to_source_mods_flags
/* Option flags for nir_lower_gs_intrinsics(). Each flag occupies its own bit,
 * and the pass tests them individually with a bitwise AND, so they may be
 * OR'ed together.
 */
typedef enum {
/* Keep a separate set of counters for every stream; when this flag is not
 * set, one counter of each kind is shared by all streams.
 */
nir_lower_gs_intrinsics_per_stream = 1 << 0,
/* Additionally maintain a primitive counter. */
nir_lower_gs_intrinsics_count_primitives = 1 << 1,
/* Additionally maintain a count of the vertices emitted for the current
 * primitive, not just the total vertex count.
 */
nir_lower_gs_intrinsics_count_vertices_per_primitive = 1 << 2,
} nir_lower_gs_intrinsics_flags;
bool nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags options);

View File

@@ -348,11 +348,13 @@ intrinsic("end_primitive", indices=[STREAM_ID])
# Alternatively, drivers may implement these intrinsics, and use
# nir_lower_gs_intrinsics() to convert from the basic intrinsics.
#
# These maintain a count of the number of vertices emitted, as an additional
# unsigned integer source.
intrinsic("emit_vertex_with_counter", src_comp=[1], indices=[STREAM_ID])
intrinsic("end_primitive_with_counter", src_comp=[1], indices=[STREAM_ID])
# Contains the final total vertex and primitive counts
# These contain two additional unsigned integer sources:
# 1. The total number of vertices emitted so far.
# 2. The number of vertices emitted for the current primitive
# so far if we're counting, otherwise undef.
intrinsic("emit_vertex_with_counter", src_comp=[1, 1], indices=[STREAM_ID])
intrinsic("end_primitive_with_counter", src_comp=[1, 1], indices=[STREAM_ID])
# Contains the final total vertex and primitive counts in the current GS thread.
intrinsic("set_vertex_and_primitive_count", src_comp=[1, 1], indices=[STREAM_ID])
# Atomic counters

View File

@@ -57,9 +57,11 @@
/* Bookkeeping shared by the emit_vertex / end_primitive rewrite helpers of
 * this lowering pass. The per-stream arrays are indexed by stream ID; when
 * per_stream is false every entry points at the same variable as entry 0.
 */
struct state {
/* NIR builder; presumably the one the rewrite helpers use to insert the
 * replacement instructions (its assignment is not visible here).
 */
nir_builder *builder;
/* Variable holding the running total of emitted vertices for each stream. */
nir_variable *vertex_count_vars[NIR_MAX_XFB_STREAMS];
/* Variable holding the number of vertices emitted for the current primitive
 * of each stream: starts at 0, is incremented after every emitted vertex and
 * is reset to 0 when a primitive ends. Only created when count_vtx_per_prim
 * is set.
 */
nir_variable *vtxcnt_per_prim_vars[NIR_MAX_XFB_STREAMS];
/* Primitive counter variable for each stream; initialized to 1 at setup. */
nir_variable *primitive_count_vars[NIR_MAX_XFB_STREAMS];
/* Mirrors nir_lower_gs_intrinsics_per_stream. */
bool per_stream;
/* Mirrors nir_lower_gs_intrinsics_count_primitives. */
bool count_prims;
/* Mirrors nir_lower_gs_intrinsics_count_vertices_per_primitive. */
bool count_vtx_per_prim;
/* Starts out false and is set to true once an intrinsic has been rewritten. */
bool progress;
};
@@ -67,8 +69,9 @@ struct state {
* Replace emit_vertex intrinsics with:
*
* if (vertex_count < max_vertices) {
* emit_vertex_with_counter vertex_count ...
* emit_vertex_with_counter vertex_count, vertex_count_per_primitive (optional) ...
* vertex_count += 1
* vertex_count_per_primitive += 1 (optional)
* }
*/
static void
@@ -81,6 +84,12 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
b->cursor = nir_before_instr(&intrin->instr);
assert(state->vertex_count_vars[stream] != NULL);
nir_ssa_def *count = nir_load_var(b, state->vertex_count_vars[stream]);
nir_ssa_def *count_per_primitive;
if (state->count_vtx_per_prim)
count_per_primitive = nir_load_var(b, state->vtxcnt_per_prim_vars[stream]);
else
count_per_primitive = nir_ssa_undef(b, 1, 32);
nir_ssa_def *max_vertices =
nir_imm_int(b, b->shader->info.gs.vertices_out);
@@ -97,6 +106,7 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
nir_intrinsic_emit_vertex_with_counter);
nir_intrinsic_set_stream_id(lowered, stream);
lowered->src[0] = nir_src_for_ssa(count);
lowered->src[1] = nir_src_for_ssa(count_per_primitive);
nir_builder_instr_insert(b, &lowered->instr);
/* Increment the vertex count by 1 */
@@ -104,6 +114,15 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
nir_iadd_imm(b, count, 1),
0x1); /* .x */
if (state->count_vtx_per_prim) {
/* Increment the per-primitive vertex count by 1 */
nir_variable *var = state->vtxcnt_per_prim_vars[stream];
nir_ssa_def *vtx_per_prim_cnt = nir_load_var(b, var);
nir_store_var(b, var,
nir_iadd_imm(b, vtx_per_prim_cnt, 1),
0x1); /* .x */
}
nir_pop_if(b, NULL);
nir_instr_remove(&intrin->instr);
@@ -123,12 +142,19 @@ rewrite_end_primitive(nir_intrinsic_instr *intrin, struct state *state)
b->cursor = nir_before_instr(&intrin->instr);
assert(state->vertex_count_vars[stream] != NULL);
nir_ssa_def *count = nir_load_var(b, state->vertex_count_vars[stream]);
nir_ssa_def *count_per_primitive;
if (state->count_vtx_per_prim)
count_per_primitive = nir_load_var(b, state->vtxcnt_per_prim_vars[stream]);
else
count_per_primitive = nir_ssa_undef(b, count->num_components, count->bit_size);
nir_intrinsic_instr *lowered =
nir_intrinsic_instr_create(b->shader,
nir_intrinsic_end_primitive_with_counter);
nir_intrinsic_set_stream_id(lowered, stream);
lowered->src[0] = nir_src_for_ssa(count);
lowered->src[1] = nir_src_for_ssa(count_per_primitive);
nir_builder_instr_insert(b, &lowered->instr);
if (state->count_prims) {
@@ -139,6 +165,13 @@ rewrite_end_primitive(nir_intrinsic_instr *intrin, struct state *state)
0x1); /* .x */
}
if (state->count_vtx_per_prim) {
/* Store 0 to per-primitive vertex count */
nir_store_var(b, state->vtxcnt_per_prim_vars[stream],
nir_imm_int(b, 0),
0x1); /* .x */
}
nir_instr_remove(&intrin->instr);
state->progress = true;
@@ -218,10 +251,13 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option
{
bool per_stream = options & nir_lower_gs_intrinsics_per_stream;
bool count_primitives = options & nir_lower_gs_intrinsics_count_primitives;
bool count_vtx_per_prim =
options & nir_lower_gs_intrinsics_count_vertices_per_primitive;
struct state state;
state.progress = false;
state.count_prims = count_primitives;
state.count_vtx_per_prim = count_vtx_per_prim;
state.per_stream = per_stream;
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
@@ -249,6 +285,12 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option
/* initialize to 1 */
nir_store_var(&b, state.primitive_count_vars[i], nir_imm_int(&b, 1), 0x1);
}
if (count_vtx_per_prim) {
state.vtxcnt_per_prim_vars[i] =
nir_local_variable_create(impl, glsl_uint_type(), "vertices_per_primitive");
/* initialize to 0 */
nir_store_var(&b, state.vtxcnt_per_prim_vars[i], nir_imm_int(&b, 0), 0x1);
}
} else {
/* If per_stream is false, we only have one counter of each kind which we
* want to use for all streams. Duplicate the counter pointers so all
@@ -258,6 +300,8 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option
if (count_primitives)
state.primitive_count_vars[i] = state.primitive_count_vars[0];
if (count_vtx_per_prim)
state.vtxcnt_per_prim_vars[i] = state.vtxcnt_per_prim_vars[0];
}
}