nir: Add ability to count emitted GS primitives.

Add an option to nir_lower_gs_intrinsics which tells it to track
the number of emitted primitives, not just vertices. Additionally,
also make it per-stream.

Also rename the set_vertex_count intrinsic to
set_vertex_and_primitive_count.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6964>
This commit is contained in:
Timur Kristóf
2020-06-08 12:16:13 +02:00
parent 73dd86c421
commit 2be99012e9
9 changed files with 77 additions and 28 deletions

View File

@@ -8022,7 +8022,7 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
bld.sopp(aco_opcode::s_sendmsg, bld.m0(ctx->gs_wave_id), -1, sendmsg_gs(true, false, stream));
break;
}
case nir_intrinsic_set_vertex_count: {
case nir_intrinsic_set_vertex_and_primitive_count: {
/* unused, the HW keeps track of this for us */
break;
}

View File

@@ -554,7 +554,7 @@ radv_shader_compile_to_nir(struct radv_device *device,
nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
if (nir->info.stage == MESA_SHADER_GEOMETRY)
nir_lower_gs_intrinsics(nir, true);
nir_lower_gs_intrinsics(nir, nir_lower_gs_intrinsics_per_stream);
static const nir_lower_tex_options tex_options = {
.lower_txp = ~0,

View File

@@ -4726,7 +4726,12 @@ typedef enum {
bool nir_lower_to_source_mods(nir_shader *shader, nir_lower_to_source_mods_flags options);
bool nir_lower_gs_intrinsics(nir_shader *shader, bool per_stream);
typedef enum {
nir_lower_gs_intrinsics_per_stream = 1 << 0,
nir_lower_gs_intrinsics_count_primitives = 1 << 1,
} nir_lower_gs_intrinsics_flags;
bool nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags options);
typedef unsigned (*nir_lower_bit_size_callback)(const nir_alu_instr *, void *);

View File

@@ -38,9 +38,9 @@ as_intrinsic(nir_instr *instr, nir_intrinsic_op op)
}
static nir_intrinsic_instr *
as_set_vertex_count(nir_instr *instr)
as_set_vertex_and_primitive_count(nir_instr *instr)
{
return as_intrinsic(instr, nir_intrinsic_set_vertex_count);
return as_intrinsic(instr, nir_intrinsic_set_vertex_and_primitive_count);
}
/**
@@ -59,14 +59,14 @@ nir_gs_count_vertices(const nir_shader *shader)
if (!function->impl)
continue;
/* set_vertex_count intrinsics only appear in predecessors of the
/* set_vertex_and_primitive_count intrinsics only appear in predecessors of the
* end block. So we don't need to walk all of them.
*/
set_foreach(function->impl->end_block->predecessors, entry) {
nir_block *block = (nir_block *) entry->key;
nir_foreach_instr_reverse(instr, block) {
nir_intrinsic_instr *intrin = as_set_vertex_count(instr);
nir_intrinsic_instr *intrin = as_set_vertex_and_primitive_count(instr);
if (!intrin)
continue;
@@ -77,7 +77,7 @@ nir_gs_count_vertices(const nir_shader *shader)
if (count == -1)
count = nir_src_as_int(intrin->src[0]);
/* We've found contradictory set_vertex_count intrinsics.
/* We've found contradictory set_vertex_and_primitive_count intrinsics.
* This can happen if there are early-returns in main() and
* different paths emit different numbers of vertices.
*/

View File

@@ -352,7 +352,8 @@ intrinsic("end_primitive", indices=[STREAM_ID])
# unsigned integer source.
intrinsic("emit_vertex_with_counter", src_comp=[1], indices=[STREAM_ID])
intrinsic("end_primitive_with_counter", src_comp=[1], indices=[STREAM_ID])
intrinsic("set_vertex_count", src_comp=[1])
# Contains the final total vertex and primitive counts
intrinsic("set_vertex_and_primitive_count", src_comp=[1, 1], indices=[STREAM_ID])
# Atomic counters
#

View File

@@ -57,6 +57,9 @@
struct state {
nir_builder *builder;
nir_variable *vertex_count_vars[NIR_MAX_XFB_STREAMS];
nir_variable *primitive_count_vars[NIR_MAX_XFB_STREAMS];
bool per_stream;
bool count_prims;
bool progress;
};
@@ -98,7 +101,7 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
/* Increment the vertex count by 1 */
nir_store_var(b, state->vertex_count_vars[stream],
nir_iadd(b, count, nir_imm_int(b, 1)),
nir_iadd_imm(b, count, 1),
0x1); /* .x */
nir_pop_if(b, NULL);
@@ -128,6 +131,14 @@ rewrite_end_primitive(nir_intrinsic_instr *intrin, struct state *state)
lowered->src[0] = nir_src_for_ssa(count);
nir_builder_instr_insert(b, &lowered->instr);
if (state->count_prims) {
/* Increment the primitive count by 1 */
nir_ssa_def *prim_cnt = nir_load_var(b, state->primitive_count_vars[stream]);
nir_store_var(b, state->primitive_count_vars[stream],
nir_iadd_imm(b, prim_cnt, 1),
0x1); /* .x */
}
nir_instr_remove(&intrin->instr);
state->progress = true;
@@ -158,11 +169,11 @@ rewrite_intrinsics(nir_block *block, struct state *state)
}
/**
* Add a set_vertex_count intrinsic at the end of the program
* (representing the final vertex count).
* Add a set_vertex_and_primitive_count intrinsic at the end of the program
* (representing the final total vertex and primitive count).
*/
static void
append_set_vertex_count(nir_block *end_block, struct state *state)
append_set_vertex_and_primitive_count(nir_block *end_block, struct state *state)
{
nir_builder *b = state->builder;
nir_shader *shader = state->builder->shader;
@@ -174,21 +185,44 @@ append_set_vertex_count(nir_block *end_block, struct state *state)
nir_block *pred = (nir_block *) entry->key;
b->cursor = nir_after_block_before_jump(pred);
nir_ssa_def *count = nir_load_var(b, state->vertex_count_vars[0]);
for (unsigned stream = 0; stream < NIR_MAX_XFB_STREAMS; ++stream) {
/* When it's not per-stream, we only need to write one variable. */
if (!state->per_stream && stream != 0)
continue;
/* When it's per-stream, make sure not to use inactive streams. */
if (state->per_stream && !(shader->info.gs.active_stream_mask & (1 << stream)))
continue;
nir_intrinsic_instr *set_vertex_count =
nir_intrinsic_instr_create(shader, nir_intrinsic_set_vertex_count);
set_vertex_count->src[0] = nir_src_for_ssa(count);
nir_ssa_def *vtx_cnt = nir_load_var(b, state->vertex_count_vars[stream]);
nir_ssa_def *prim_cnt;
nir_builder_instr_insert(b, &set_vertex_count->instr);
if (state->count_prims)
prim_cnt = nir_load_var(b, state->primitive_count_vars[stream]);
else
prim_cnt = nir_ssa_undef(b, 1, 32);
nir_intrinsic_instr *set_cnt_intrin =
nir_intrinsic_instr_create(shader,
nir_intrinsic_set_vertex_and_primitive_count);
nir_intrinsic_set_stream_id(set_cnt_intrin, stream);
set_cnt_intrin->src[0] = nir_src_for_ssa(vtx_cnt);
set_cnt_intrin->src[1] = nir_src_for_ssa(prim_cnt);
nir_builder_instr_insert(b, &set_cnt_intrin->instr);
}
}
}
bool
nir_lower_gs_intrinsics(nir_shader *shader, bool per_stream)
nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags options)
{
bool per_stream = options & nir_lower_gs_intrinsics_per_stream;
bool count_primitives = options & nir_lower_gs_intrinsics_count_primitives;
struct state state;
state.progress = false;
state.count_prims = count_primitives;
state.per_stream = per_stream;
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
assert(impl);
@@ -197,8 +231,8 @@ nir_lower_gs_intrinsics(nir_shader *shader, bool per_stream)
nir_builder_init(&b, impl);
state.builder = &b;
/* Create the counter variables */
b.cursor = nir_before_cf_list(&impl->body);
for (unsigned i = 0; i < NIR_MAX_XFB_STREAMS; i++) {
if (per_stream && !(shader->info.gs.active_stream_mask & (1 << i)))
continue;
@@ -208,12 +242,22 @@ nir_lower_gs_intrinsics(nir_shader *shader, bool per_stream)
nir_local_variable_create(impl, glsl_uint_type(), "vertex_count");
/* initialize to 0 */
nir_store_var(&b, state.vertex_count_vars[i], nir_imm_int(&b, 0), 0x1);
if (count_primitives) {
state.primitive_count_vars[i] =
nir_local_variable_create(impl, glsl_uint_type(), "primitive_count");
/* initialize to 1 */
nir_store_var(&b, state.primitive_count_vars[i], nir_imm_int(&b, 1), 0x1);
}
} else {
/* If per_stream is false, we only have one counter which we want to use
* for all streams. Duplicate the counter pointer so all streams use the
* same counter.
/* If per_stream is false, we only have one counter of each kind which we
* want to use for all streams. Duplicate the counter pointers so all
* streams use the same counters.
*/
state.vertex_count_vars[i] = state.vertex_count_vars[0];
if (count_primitives)
state.primitive_count_vars[i] = state.primitive_count_vars[0];
}
}
@@ -221,8 +265,7 @@ nir_lower_gs_intrinsics(nir_shader *shader, bool per_stream)
rewrite_intrinsics(block, &state);
/* This only works because we have a single main() function. */
if (!per_stream)
append_set_vertex_count(impl->end_block, &state);
append_set_vertex_and_primitive_count(impl->end_block, &state);
nir_metadata_preserve(impl, 0);

View File

@@ -3149,7 +3149,7 @@ fs_visitor::nir_emit_gs_intrinsic(const fs_builder &bld,
emit_gs_end_primitive(instr->src[0]);
break;
case nir_intrinsic_set_vertex_count:
case nir_intrinsic_set_vertex_and_primitive_count:
bld.MOV(this->final_gs_vertex_count, get_nir_src(instr->src[0]));
break;

View File

@@ -692,7 +692,7 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir,
}
if (nir->info.stage == MESA_SHADER_GEOMETRY)
OPT(nir_lower_gs_intrinsics, false);
OPT(nir_lower_gs_intrinsics, 0);
/* See also brw_nir_trig_workarounds.py */
if (compiler->precise_trig &&

View File

@@ -78,7 +78,7 @@ vec4_gs_visitor::nir_emit_intrinsic(nir_intrinsic_instr *instr)
gs_end_primitive();
break;
case nir_intrinsic_set_vertex_count:
case nir_intrinsic_set_vertex_and_primitive_count:
this->vertex_count =
retype(get_nir_src(instr->src[0], 1), BRW_REGISTER_TYPE_UD);
break;