diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index f351e3ca301..d47d772ccbb 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -6030,6 +6030,7 @@ typedef enum { nir_lower_gs_intrinsics_count_vertices_per_primitive = 1 << 2, nir_lower_gs_intrinsics_overwrite_incomplete = 1 << 3, nir_lower_gs_intrinsics_always_end_primitive = 1 << 4, + nir_lower_gs_intrinsics_count_decomposed_primitives = 1 << 5, } nir_lower_gs_intrinsics_flags; bool nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags options); diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py index e5c9c592112..1f7048f646f 100644 --- a/src/compiler/nir/nir_intrinsics.py +++ b/src/compiler/nir/nir_intrinsics.py @@ -536,15 +536,19 @@ intrinsic("end_primitive", indices=[STREAM_ID]) # Alternatively, drivers may implement these intrinsics, and use # nir_lower_gs_intrinsics() to convert from the basic intrinsics. # -# These contain three additional unsigned integer sources: +# These contain four additional unsigned integer sources: # 1. The total number of vertices emitted so far. # 2. The number of vertices emitted for the current primitive # so far if we're counting, otherwise undef. # 3. The total number of primitives emitted so far. -intrinsic("emit_vertex_with_counter", src_comp=[1, 1, 1], indices=[STREAM_ID]) -intrinsic("end_primitive_with_counter", src_comp=[1, 1, 1], indices=[STREAM_ID]) -# Contains the final total vertex and primitive counts in the current GS thread. -intrinsic("set_vertex_and_primitive_count", src_comp=[1, 1], indices=[STREAM_ID]) +# 4. The total number of decomposed primitives emitted so far. This counts like +# the PRIMITIVES_GENERATED query: a triangle strip with 5 vertices is counted +# as 3 primitives (not 1). +intrinsic("emit_vertex_with_counter", src_comp=[1, 1, 1, 1], indices=[STREAM_ID]) +intrinsic("end_primitive_with_counter", src_comp=[1, 1, 1, 1], indices=[STREAM_ID]) +# Contains the final total vertex, primitive, and decomposed primitives counts +# in the current GS thread. +intrinsic("set_vertex_and_primitive_count", src_comp=[1, 1, 1], indices=[STREAM_ID]) # Launches mesh shader workgroups from a task shader, with explicit task_payload. # Rules: diff --git a/src/compiler/nir/nir_lower_gs_intrinsics.c b/src/compiler/nir/nir_lower_gs_intrinsics.c index 735abf60724..15cc69663fd 100644 --- a/src/compiler/nir/nir_lower_gs_intrinsics.c +++ b/src/compiler/nir/nir_lower_gs_intrinsics.c @@ -59,14 +59,31 @@ struct state { nir_variable *vertex_count_vars[NIR_MAX_XFB_STREAMS]; nir_variable *vtxcnt_per_prim_vars[NIR_MAX_XFB_STREAMS]; nir_variable *primitive_count_vars[NIR_MAX_XFB_STREAMS]; + nir_variable *decomposed_primitive_count_vars[NIR_MAX_XFB_STREAMS]; bool per_stream; bool count_prims; bool count_vtx_per_prim; + bool count_decomposed_prims; bool overwrite_incomplete; bool is_points; bool progress; }; +static unsigned +decomposed_primitive_size(nir_builder *b) +{ + enum mesa_prim outprim = b->shader->info.gs.output_primitive; + + if (outprim == MESA_PRIM_POINTS) + return 1; + else if (outprim == MESA_PRIM_LINE_STRIP) + return 2; + else if (outprim == MESA_PRIM_TRIANGLE_STRIP) + return 3; + else + unreachable("Invalid GS output primitive type."); +} + /** * Replace emit_vertex intrinsics with: * @@ -88,6 +105,7 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state) nir_def *count = nir_load_var(b, state->vertex_count_vars[stream]); nir_def *count_per_primitive; nir_def *primitive_count; + nir_def *decomposed_primitive_count; if (state->count_vtx_per_prim) count_per_primitive = nir_load_var(b, state->vtxcnt_per_prim_vars[stream]); @@ -101,6 +119,13 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state) else primitive_count = nir_undef(b, 1, 32); + if (state->count_decomposed_prims) { + decomposed_primitive_count = + nir_load_var(b, state->decomposed_primitive_count_vars[stream]); + } else { + decomposed_primitive_count = nir_undef(b, 1, 32); + } + /* Create: if (vertex_count < max_vertices) and insert it. * * The new if statement needs to be hooked up to the control flow graph @@ -109,7 +134,7 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state) nir_push_if(b, nir_ilt_imm(b, count, b->shader->info.gs.vertices_out)); nir_emit_vertex_with_counter(b, count, count_per_primitive, primitive_count, - stream); + decomposed_primitive_count, stream); /* Increment the vertex count by 1 */ nir_store_var(b, state->vertex_count_vars[stream], @@ -125,6 +150,27 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state) 0x1); /* .x */ } + if (state->count_decomposed_prims) { + nir_variable *vtx_var = state->vtxcnt_per_prim_vars[stream]; + nir_def *vtx_per_prim_cnt = state->is_points ? nir_imm_int(b, 1) : + nir_load_var(b, vtx_var); + + /* We form a new primitive for every vertex emitted after the first + * complete primitive (since we're outputting strips). + */ + unsigned min_verts = decomposed_primitive_size(b); + nir_def *new_prim = nir_uge_imm(b, vtx_per_prim_cnt, min_verts); + + /* Increment the decomposed primitive count by 1 if we formed a complete + * primitive. + */ + nir_variable *var = state->decomposed_primitive_count_vars[stream]; + nir_def *cnt = nir_load_var(b, var); + nir_store_var(b, var, + nir_iadd(b, cnt, nir_b2i32(b, new_prim)), + 0x1); /* .x */ + } + nir_pop_if(b, NULL); nir_instr_remove(&intrin->instr); @@ -154,17 +200,7 @@ overwrite_incomplete_primitives(struct state *state, unsigned stream) assert(state->count_vtx_per_prim); nir_builder *b = state->builder; - enum mesa_prim outprim = b->shader->info.gs.output_primitive; - unsigned outprim_min_vertices; - - if (outprim == MESA_PRIM_POINTS) - outprim_min_vertices = 1; - else if (outprim == MESA_PRIM_LINE_STRIP) - outprim_min_vertices = 2; - else if (outprim == MESA_PRIM_TRIANGLE_STRIP) - outprim_min_vertices = 3; - else - unreachable("Invalid GS output primitive type."); + unsigned outprim_min_vertices = decomposed_primitive_size(b); /* Total count of vertices emitted so far. */ nir_def *vtxcnt_total = @@ -221,6 +257,7 @@ rewrite_end_primitive(nir_intrinsic_instr *intrin, struct state *state) nir_def *count = nir_load_var(b, state->vertex_count_vars[stream]); nir_def *count_per_primitive; nir_def *primitive_count; + nir_def *decomposed_primitive_count; if (state->count_vtx_per_prim) count_per_primitive = nir_load_var(b, state->vtxcnt_per_prim_vars[stream]); @@ -232,8 +269,16 @@ rewrite_end_primitive(nir_intrinsic_instr *intrin, struct state *state) else primitive_count = nir_undef(b, 1, 32); + if (state->count_decomposed_prims) { + decomposed_primitive_count = + nir_load_var(b, state->decomposed_primitive_count_vars[stream]); + } else { + decomposed_primitive_count = nir_undef(b, 1, 32); + } + nir_end_primitive_with_counter(b, count, count_per_primitive, - primitive_count, stream); + primitive_count, + decomposed_primitive_count, stream); if (state->count_prims) { /* Increment the primitive count by 1 */ @@ -304,6 +349,7 @@ append_set_vertex_and_primitive_count(nir_block *end_block, struct state *state) nir_def *vtx_cnt; nir_def *prim_cnt; + nir_def *decomposed_prim_cnt; if (state->per_stream && !(shader->info.gs.active_stream_mask & (1 << stream))) { /* Inactive stream: vertex count is 0, primitive count is 0 or undef. */ @@ -311,6 +357,7 @@ append_set_vertex_and_primitive_count(nir_block *end_block, struct state *state) prim_cnt = state->count_prims || state->is_points ? nir_imm_int(b, 0) : nir_undef(b, 1, 32); + decomposed_prim_cnt = prim_cnt; } else { if (state->overwrite_incomplete) overwrite_incomplete_primitives(state, stream); @@ -326,9 +373,17 @@ append_set_vertex_and_primitive_count(nir_block *end_block, struct state *state) prim_cnt = vtx_cnt; else prim_cnt = nir_undef(b, 1, 32); + + if (state->count_decomposed_prims) { + decomposed_prim_cnt = + nir_load_var(b, state->decomposed_primitive_count_vars[stream]); + } else { + decomposed_prim_cnt = nir_undef(b, 1, 32); + } } - nir_set_vertex_and_primitive_count(b, vtx_cnt, prim_cnt, stream); + nir_set_vertex_and_primitive_count(b, vtx_cnt, prim_cnt, + decomposed_prim_cnt, stream); state->progress = true; } } @@ -411,6 +466,7 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option bool count_vtx_per_prim = overwrite_incomplete || (options & nir_lower_gs_intrinsics_count_vertices_per_primitive); + bool count_decomposed_prims = options & nir_lower_gs_intrinsics_count_decomposed_primitives; bool is_points = shader->info.gs.output_primitive == MESA_PRIM_POINTS; /* points are always complete primitives with a single vertex, so these are @@ -426,6 +482,7 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option state.progress = false; state.count_prims = count_primitives; state.count_vtx_per_prim = count_vtx_per_prim; + state.count_decomposed_prims = count_decomposed_prims; state.overwrite_incomplete = overwrite_incomplete; state.per_stream = per_stream; state.is_points = is_points; @@ -461,6 +518,13 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option /* initialize to 0 */ nir_store_var(&b, state.vtxcnt_per_prim_vars[i], nir_imm_int(&b, 0), 0x1); } + if (count_decomposed_prims) { + state.decomposed_primitive_count_vars[i] = + nir_local_variable_create(impl, glsl_uint_type(), "decomposed_primitive_count"); + /* initialize to 0 */ + nir_store_var(&b, state.decomposed_primitive_count_vars[i], + nir_imm_int(&b, 0), 0x1); + } } else { /* If per_stream is false, we only have one counter of each kind which we * want to use for all streams. Duplicate the counter pointers so all @@ -472,6 +536,8 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option state.primitive_count_vars[i] = state.primitive_count_vars[0]; if (count_vtx_per_prim) state.vtxcnt_per_prim_vars[i] = state.vtxcnt_per_prim_vars[0]; + if (count_decomposed_prims) + state.decomposed_primitive_count_vars[i] = state.decomposed_primitive_count_vars[0]; } } diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index ccb5497edc6..6747c0b9771 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -6654,7 +6654,8 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, case SpvOpSetMeshOutputsEXT: nir_set_vertex_and_primitive_count( - &b->nb, vtn_get_nir_ssa(b, w[1]), vtn_get_nir_ssa(b, w[2])); + &b->nb, vtn_get_nir_ssa(b, w[1]), vtn_get_nir_ssa(b, w[2]), + nir_undef(&b->nb, 1, 32)); break; case SpvOpInitializeNodePayloadsAMDX: