nir/lower_gs_intrinsics: Count decomposed primitives too

We need both: decomposed primitives for transform feedback and regular
primitives for sizing the index buffer.

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Antonino Maniscalco <antonino.maniscalco@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26056>
This commit is contained in:
Alyssa Rosenzweig
2023-08-14 16:45:44 -04:00
committed by Marge Bot
parent 0a35aa3a2b
commit b65636ca40
4 changed files with 92 additions and 20 deletions

View File

@@ -6030,6 +6030,7 @@ typedef enum {
nir_lower_gs_intrinsics_count_vertices_per_primitive = 1 << 2,
nir_lower_gs_intrinsics_overwrite_incomplete = 1 << 3,
nir_lower_gs_intrinsics_always_end_primitive = 1 << 4,
nir_lower_gs_intrinsics_count_decomposed_primitives = 1 << 5,
} nir_lower_gs_intrinsics_flags;
bool nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags options);

View File

@@ -536,15 +536,19 @@ intrinsic("end_primitive", indices=[STREAM_ID])
# Alternatively, drivers may implement these intrinsics, and use
# nir_lower_gs_intrinsics() to convert from the basic intrinsics.
#
# These contain three additional unsigned integer sources:
# These contain four additional unsigned integer sources:
# 1. The total number of vertices emitted so far.
# 2. The number of vertices emitted for the current primitive
# so far if we're counting, otherwise undef.
# 3. The total number of primitives emitted so far.
intrinsic("emit_vertex_with_counter", src_comp=[1, 1, 1], indices=[STREAM_ID])
intrinsic("end_primitive_with_counter", src_comp=[1, 1, 1], indices=[STREAM_ID])
# Contains the final total vertex and primitive counts in the current GS thread.
intrinsic("set_vertex_and_primitive_count", src_comp=[1, 1], indices=[STREAM_ID])
# 4. The total number of decomposed primitives emitted so far. This counts like
# the PRIMITIVES_GENERATED query: a triangle strip with 5 vertices is counted
# as 3 primitives (not 1).
intrinsic("emit_vertex_with_counter", src_comp=[1, 1, 1, 1], indices=[STREAM_ID])
intrinsic("end_primitive_with_counter", src_comp=[1, 1, 1, 1], indices=[STREAM_ID])
# Contains the final total vertex, primitive, and decomposed primitive counts
# in the current GS thread.
intrinsic("set_vertex_and_primitive_count", src_comp=[1, 1, 1], indices=[STREAM_ID])
# Launches mesh shader workgroups from a task shader, with explicit task_payload.
# Rules:

View File

@@ -59,14 +59,31 @@ struct state {
nir_variable *vertex_count_vars[NIR_MAX_XFB_STREAMS];
nir_variable *vtxcnt_per_prim_vars[NIR_MAX_XFB_STREAMS];
nir_variable *primitive_count_vars[NIR_MAX_XFB_STREAMS];
nir_variable *decomposed_primitive_count_vars[NIR_MAX_XFB_STREAMS];
bool per_stream;
bool count_prims;
bool count_vtx_per_prim;
bool count_decomposed_prims;
bool overwrite_incomplete;
bool is_points;
bool progress;
};
/**
 * Number of vertices in one decomposed primitive of the geometry shader's
 * output primitive type: 1 for points, 2 for line strips, 3 for triangle
 * strips.
 *
 * Callers use this as the minimum number of vertices a strip needs before it
 * contains a complete primitive.  Any other output primitive type is treated
 * as unreachable.
 */
static unsigned
decomposed_primitive_size(nir_builder *b)
{
   switch (b->shader->info.gs.output_primitive) {
   case MESA_PRIM_POINTS:
      return 1;
   case MESA_PRIM_LINE_STRIP:
      return 2;
   case MESA_PRIM_TRIANGLE_STRIP:
      return 3;
   default:
      unreachable("Invalid GS output primitive type.");
   }
}
/**
* Replace emit_vertex intrinsics with:
*
@@ -88,6 +105,7 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
nir_def *count = nir_load_var(b, state->vertex_count_vars[stream]);
nir_def *count_per_primitive;
nir_def *primitive_count;
nir_def *decomposed_primitive_count;
if (state->count_vtx_per_prim)
count_per_primitive = nir_load_var(b, state->vtxcnt_per_prim_vars[stream]);
@@ -101,6 +119,13 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
else
primitive_count = nir_undef(b, 1, 32);
if (state->count_decomposed_prims) {
decomposed_primitive_count =
nir_load_var(b, state->decomposed_primitive_count_vars[stream]);
} else {
decomposed_primitive_count = nir_undef(b, 1, 32);
}
/* Create: if (vertex_count < max_vertices) and insert it.
*
* The new if statement needs to be hooked up to the control flow graph
@@ -109,7 +134,7 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
nir_push_if(b, nir_ilt_imm(b, count, b->shader->info.gs.vertices_out));
nir_emit_vertex_with_counter(b, count, count_per_primitive, primitive_count,
stream);
decomposed_primitive_count, stream);
/* Increment the vertex count by 1 */
nir_store_var(b, state->vertex_count_vars[stream],
@@ -125,6 +150,27 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
0x1); /* .x */
}
if (state->count_decomposed_prims) {
nir_variable *vtx_var = state->vtxcnt_per_prim_vars[stream];
nir_def *vtx_per_prim_cnt = state->is_points ? nir_imm_int(b, 1) :
nir_load_var(b, vtx_var);
/* We form a new primitive for every vertex emitted after the first
* complete primitive (since we're outputting strips).
*/
unsigned min_verts = decomposed_primitive_size(b);
nir_def *new_prim = nir_uge_imm(b, vtx_per_prim_cnt, min_verts);
/* Increment the decomposed primitive count by 1 if we formed a complete
* primitive.
*/
nir_variable *var = state->decomposed_primitive_count_vars[stream];
nir_def *cnt = nir_load_var(b, var);
nir_store_var(b, var,
nir_iadd(b, cnt, nir_b2i32(b, new_prim)),
0x1); /* .x */
}
nir_pop_if(b, NULL);
nir_instr_remove(&intrin->instr);
@@ -154,17 +200,7 @@ overwrite_incomplete_primitives(struct state *state, unsigned stream)
assert(state->count_vtx_per_prim);
nir_builder *b = state->builder;
enum mesa_prim outprim = b->shader->info.gs.output_primitive;
unsigned outprim_min_vertices;
if (outprim == MESA_PRIM_POINTS)
outprim_min_vertices = 1;
else if (outprim == MESA_PRIM_LINE_STRIP)
outprim_min_vertices = 2;
else if (outprim == MESA_PRIM_TRIANGLE_STRIP)
outprim_min_vertices = 3;
else
unreachable("Invalid GS output primitive type.");
unsigned outprim_min_vertices = decomposed_primitive_size(b);
/* Total count of vertices emitted so far. */
nir_def *vtxcnt_total =
@@ -221,6 +257,7 @@ rewrite_end_primitive(nir_intrinsic_instr *intrin, struct state *state)
nir_def *count = nir_load_var(b, state->vertex_count_vars[stream]);
nir_def *count_per_primitive;
nir_def *primitive_count;
nir_def *decomposed_primitive_count;
if (state->count_vtx_per_prim)
count_per_primitive = nir_load_var(b, state->vtxcnt_per_prim_vars[stream]);
@@ -232,8 +269,16 @@ rewrite_end_primitive(nir_intrinsic_instr *intrin, struct state *state)
else
primitive_count = nir_undef(b, 1, 32);
if (state->count_decomposed_prims) {
decomposed_primitive_count =
nir_load_var(b, state->decomposed_primitive_count_vars[stream]);
} else {
decomposed_primitive_count = nir_undef(b, 1, 32);
}
nir_end_primitive_with_counter(b, count, count_per_primitive,
primitive_count, stream);
primitive_count,
decomposed_primitive_count, stream);
if (state->count_prims) {
/* Increment the primitive count by 1 */
@@ -304,6 +349,7 @@ append_set_vertex_and_primitive_count(nir_block *end_block, struct state *state)
nir_def *vtx_cnt;
nir_def *prim_cnt;
nir_def *decomposed_prim_cnt;
if (state->per_stream && !(shader->info.gs.active_stream_mask & (1 << stream))) {
/* Inactive stream: vertex count is 0, primitive count is 0 or undef. */
@@ -311,6 +357,7 @@ append_set_vertex_and_primitive_count(nir_block *end_block, struct state *state)
prim_cnt = state->count_prims || state->is_points
? nir_imm_int(b, 0)
: nir_undef(b, 1, 32);
decomposed_prim_cnt = prim_cnt;
} else {
if (state->overwrite_incomplete)
overwrite_incomplete_primitives(state, stream);
@@ -326,9 +373,17 @@ append_set_vertex_and_primitive_count(nir_block *end_block, struct state *state)
prim_cnt = vtx_cnt;
else
prim_cnt = nir_undef(b, 1, 32);
if (state->count_decomposed_prims) {
decomposed_prim_cnt =
nir_load_var(b, state->decomposed_primitive_count_vars[stream]);
} else {
decomposed_prim_cnt = nir_undef(b, 1, 32);
}
}
nir_set_vertex_and_primitive_count(b, vtx_cnt, prim_cnt, stream);
nir_set_vertex_and_primitive_count(b, vtx_cnt, prim_cnt,
decomposed_prim_cnt, stream);
state->progress = true;
}
}
@@ -411,6 +466,7 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option
bool count_vtx_per_prim =
overwrite_incomplete ||
(options & nir_lower_gs_intrinsics_count_vertices_per_primitive);
bool count_decomposed_prims = options & nir_lower_gs_intrinsics_count_decomposed_primitives;
bool is_points = shader->info.gs.output_primitive == MESA_PRIM_POINTS;
/* points are always complete primitives with a single vertex, so these are
@@ -426,6 +482,7 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option
state.progress = false;
state.count_prims = count_primitives;
state.count_vtx_per_prim = count_vtx_per_prim;
state.count_decomposed_prims = count_decomposed_prims;
state.overwrite_incomplete = overwrite_incomplete;
state.per_stream = per_stream;
state.is_points = is_points;
@@ -461,6 +518,13 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option
/* initialize to 0 */
nir_store_var(&b, state.vtxcnt_per_prim_vars[i], nir_imm_int(&b, 0), 0x1);
}
if (count_decomposed_prims) {
state.decomposed_primitive_count_vars[i] =
nir_local_variable_create(impl, glsl_uint_type(), "decomposed_primitive_count");
/* initialize to 0 */
nir_store_var(&b, state.decomposed_primitive_count_vars[i],
nir_imm_int(&b, 0), 0x1);
}
} else {
/* If per_stream is false, we only have one counter of each kind which we
* want to use for all streams. Duplicate the counter pointers so all
@@ -472,6 +536,8 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option
state.primitive_count_vars[i] = state.primitive_count_vars[0];
if (count_vtx_per_prim)
state.vtxcnt_per_prim_vars[i] = state.vtxcnt_per_prim_vars[0];
if (count_decomposed_prims)
state.decomposed_primitive_count_vars[i] = state.decomposed_primitive_count_vars[0];
}
}

View File

@@ -6654,7 +6654,8 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvOpSetMeshOutputsEXT:
nir_set_vertex_and_primitive_count(
&b->nb, vtn_get_nir_ssa(b, w[1]), vtn_get_nir_ssa(b, w[2]));
&b->nb, vtn_get_nir_ssa(b, w[1]), vtn_get_nir_ssa(b, w[2]),
nir_undef(&b->nb, 1, 32));
break;
case SpvOpInitializeNodePayloadsAMDX: