radeonsi: rewrite the prefix sum computation for shader culling

Instead of storing the vertex mask per wave into LDS and then computing
the prefix sum, store 8-bit bitcounts (vertex counts) of the vertex masks
into LDS. This allows us to compute the sum using v_sad_u8, which sums
all 4 components of an i8vec4 in one instruction.

Each i8vec4 of vertex counts is loaded in parallel threads (one dword
per thread) instead of all being loaded in thread 0, and readlane copies
them to SGPRs instead of readfirstlane.

LDS is no longer initialized before culling. Instead, the counts for
inactive waves are masked with AND later.

Incorrect old comments are also fixed.

This change removes 80 bytes from the code size, and it allows increasing
the workgroup size from 128 to 256 (which is the main motivation for this).

Now changing the workgroup size with wave64 has no effect on the code size.
Switching to wave32 with 8 waves even generates slightly smaller code than
wave64 with 4 waves.

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/10813>
This commit is contained in:
Marek Olšák
2021-05-08 02:41:52 -04:00
committed by Marge Bot
parent 27c9e77c6a
commit 13acbaecd8
3 changed files with 133 additions and 151 deletions

View File

@@ -3293,17 +3293,21 @@ LLVMValueRef ac_trim_vector(struct ac_llvm_context *ctx, LLVMValueRef value, uns
return LLVMBuildShuffleVector(ctx->builder, value, value, swizzle, ""); return LLVMBuildShuffleVector(ctx->builder, value, value, swizzle, "");
} }
/* If param is i64 and bitwidth <= 32, the return value will be i32. */
LLVMValueRef ac_unpack_param(struct ac_llvm_context *ctx, LLVMValueRef param, unsigned rshift, LLVMValueRef ac_unpack_param(struct ac_llvm_context *ctx, LLVMValueRef param, unsigned rshift,
unsigned bitwidth) unsigned bitwidth)
{ {
LLVMValueRef value = param; LLVMValueRef value = param;
if (rshift) if (rshift)
value = LLVMBuildLShr(ctx->builder, value, LLVMConstInt(ctx->i32, rshift, false), ""); value = LLVMBuildLShr(ctx->builder, value, LLVMConstInt(LLVMTypeOf(param), rshift, false), "");
if (rshift + bitwidth < 32) { if (rshift + bitwidth < 32) {
unsigned mask = (1 << bitwidth) - 1; uint64_t mask = (1ull << bitwidth) - 1;
value = LLVMBuildAnd(ctx->builder, value, LLVMConstInt(ctx->i32, mask, false), ""); value = LLVMBuildAnd(ctx->builder, value, LLVMConstInt(LLVMTypeOf(param), mask, false), "");
} }
if (bitwidth <= 32 && LLVMTypeOf(param) == ctx->i64)
value = LLVMBuildTrunc(ctx->builder, value, ctx->i32, "");
return value; return value;
} }
@@ -4723,64 +4727,6 @@ void ac_build_s_endpgm(struct ac_llvm_context *ctx)
LLVMBuildCall(ctx->builder, code, NULL, 0, ""); LLVMBuildCall(ctx->builder, code, NULL, 0, "");
} }
/* Count the bits of "mask" that lie below bit position "index".
 *
 * Builds (1 << index) - 1 in the integer type of "mask", ANDs it with "mask",
 * and returns ac_build_bit_count() of the result. "index" is zero-extended to
 * the type of "mask" before it is used as the shift amount.
 *
 * NOTE(review): presumably "index" must be smaller than the bit width of
 * "mask" for the shift to be well-defined — confirm against callers.
 */
LLVMValueRef ac_prefix_bitcount(struct ac_llvm_context *ctx, LLVMValueRef mask, LLVMValueRef index)
{
LLVMBuilderRef builder = ctx->builder;
LLVMTypeRef type = LLVMTypeOf(mask);
/* bit = 1 << index */
LLVMValueRef bit =
LLVMBuildShl(builder, LLVMConstInt(type, 1, 0), LLVMBuildZExt(builder, index, type, ""), "");
/* prefix_bits = bit - 1: every bit position below "index" is set. */
LLVMValueRef prefix_bits = LLVMBuildSub(builder, bit, LLVMConstInt(type, 1, 0), "");
/* Keep only the mask bits that precede "index". */
LLVMValueRef prefix_mask = LLVMBuildAnd(builder, mask, prefix_bits, "");
return ac_build_bit_count(ctx, prefix_mask);
}
/* Compute the prefix sum of the "mask" bit array with 128 elements (bits).
 *
 * mask[0] and mask[1] are the two 64-bit halves of the bit array. The result
 * is the bit count of mask[0] under the low prefix mask plus the bit count of
 * mask[1] under the high prefix mask, with selects covering the shift edge
 * cases (index == 0 and index > 64).
 *
 * NOTE(review): mask[0] is presumably the low half (bits 0..63) and "index" is
 * presumably expected to be below 128 — confirm against callers.
 */
LLVMValueRef ac_prefix_bitcount_2x64(struct ac_llvm_context *ctx, LLVMValueRef mask[2],
LLVMValueRef index)
{
LLVMBuilderRef builder = ctx->builder;
#if 0
/* Reference version using i128. */
LLVMValueRef input_mask =
LLVMBuildBitCast(builder, ac_build_gather_values(ctx, mask, 2), ctx->i128, "");
return ac_prefix_bitcount(ctx, input_mask, index);
#else
/* Optimized version using 2 64-bit masks. */
LLVMValueRef is_hi, is_0, c64, c128, all_bits;
LLVMValueRef prefix_mask[2], shift[2], mask_bcnt0, prefix_bcnt[2];
/* Compute the 128-bit prefix mask. */
c64 = LLVMConstInt(ctx->i32, 64, 0);
c128 = LLVMConstInt(ctx->i32, 128, 0);
all_bits = LLVMConstInt(ctx->i64, UINT64_MAX, 0);
/* The first index that can have non-zero high bits in the prefix mask is 65. */
is_hi = LLVMBuildICmp(builder, LLVMIntUGT, index, c64, "");
is_0 = LLVMBuildICmp(builder, LLVMIntEQ, index, ctx->i32_0, "");
/* Bit count of the whole mask[0]; selected below when index > 64. */
mask_bcnt0 = ac_build_bit_count(ctx, mask[0]);
for (unsigned i = 0; i < 2; i++) {
/* shift = 64 - index for mask[0], 128 - index for mask[1]. */
shift[i] = LLVMBuildSub(builder, i ? c128 : c64, index, "");
/* For i==0, index==0, the right shift by 64 doesn't give the desired result,
* so we handle it by the is_0 select.
* For i==1, index==64, same story, so we handle it by the last is_hi select.
* For i==0, index==64, we shift by 0, which is what we want.
*/
/* prefix_mask = ~0ull >> shift, then restricted to the bits set in mask[i]. */
prefix_mask[i] =
LLVMBuildLShr(builder, all_bits, LLVMBuildZExt(builder, shift[i], ctx->i64, ""), "");
prefix_mask[i] = LLVMBuildAnd(builder, mask[i], prefix_mask[i], "");
prefix_bcnt[i] = ac_build_bit_count(ctx, prefix_mask[i]);
}
/* index == 0: nothing precedes the index, so the mask[0] contribution is 0. */
prefix_bcnt[0] = LLVMBuildSelect(builder, is_0, ctx->i32_0, prefix_bcnt[0], "");
/* index > 64: all of mask[0] precedes the index, so use its full bit count. */
prefix_bcnt[0] = LLVMBuildSelect(builder, is_hi, mask_bcnt0, prefix_bcnt[0], "");
/* mask[1] only contributes when index > 64; otherwise its count is 0. */
prefix_bcnt[1] = LLVMBuildSelect(builder, is_hi, prefix_bcnt[1], ctx->i32_0, "");
return LLVMBuildAdd(builder, prefix_bcnt[0], prefix_bcnt[1], "");
#endif
}
/** /**
* Convert triangle strip indices to triangle indices. This is used to decompose * Convert triangle strip indices to triangle indices. This is used to decompose
* triangle strips into triangles. * triangle strips into triangles.

View File

@@ -607,9 +607,6 @@ LLVMValueRef ac_build_main(const struct ac_shader_args *args, struct ac_llvm_con
LLVMTypeRef ret_type, LLVMModuleRef module); LLVMTypeRef ret_type, LLVMModuleRef module);
void ac_build_s_endpgm(struct ac_llvm_context *ctx); void ac_build_s_endpgm(struct ac_llvm_context *ctx);
LLVMValueRef ac_prefix_bitcount(struct ac_llvm_context *ctx, LLVMValueRef mask, LLVMValueRef index);
LLVMValueRef ac_prefix_bitcount_2x64(struct ac_llvm_context *ctx, LLVMValueRef mask[2],
LLVMValueRef index);
void ac_build_triangle_strip_indices_to_triangle(struct ac_llvm_context *ctx, LLVMValueRef is_odd, void ac_build_triangle_strip_indices_to_triangle(struct ac_llvm_context *ctx, LLVMValueRef is_odd,
LLVMValueRef flatshade_first, LLVMValueRef flatshade_first,
LLVMValueRef index[3]); LLVMValueRef index[3]);

View File

@@ -577,15 +577,20 @@ enum
lds_tes_patch_id, /* optional */ lds_tes_patch_id, /* optional */
}; };
/* Return a pointer to the byte at offset "index" from "ptr".
 *
 * "ptr" is cast to an i8 pointer in AC_ADDR_SPACE_LDS, then a single-index GEP
 * is built with "index", which is an LLVM value and so may be computed at
 * shader run time.
 */
static LLVMValueRef si_build_gep_i8_var(struct si_shader_context *ctx, LLVMValueRef ptr,
LLVMValueRef index)
{
LLVMTypeRef pi8 = LLVMPointerType(ctx->ac.i8, AC_ADDR_SPACE_LDS);
/* Reinterpret the pointer as i8* so that the GEP index counts bytes. */
return LLVMBuildGEP(ctx->ac.builder, LLVMBuildPointerCast(ctx->ac.builder, ptr, pi8, ""), &index,
1, "");
}
static LLVMValueRef si_build_gep_i8(struct si_shader_context *ctx, LLVMValueRef ptr, static LLVMValueRef si_build_gep_i8(struct si_shader_context *ctx, LLVMValueRef ptr,
unsigned byte_index) unsigned byte_index)
{ {
assert(byte_index < 4); assert(byte_index < 4);
LLVMTypeRef pi8 = LLVMPointerType(ctx->ac.i8, AC_ADDR_SPACE_LDS); return si_build_gep_i8_var(ctx, ptr, LLVMConstInt(ctx->ac.i32, byte_index, 0));
LLVMValueRef index = LLVMConstInt(ctx->ac.i32, byte_index, 0);
return LLVMBuildGEP(ctx->ac.builder, LLVMBuildPointerCast(ctx->ac.builder, ptr, pi8, ""), &index,
1, "");
} }
static unsigned ngg_nogs_vertex_size(struct si_shader *shader) static unsigned ngg_nogs_vertex_size(struct si_shader *shader)
@@ -653,42 +658,87 @@ static LLVMValueRef si_insert_input_v4i32(struct si_shader_context *ctx, LLVMVal
return ret; return ret;
} }
static void load_bitmasks_2x64(struct si_shader_context *ctx, LLVMValueRef lds_ptr, static void load_vertex_counts(struct si_shader_context *ctx, LLVMValueRef lds,
LLVMValueRef tid, unsigned max_waves, LLVMValueRef tid,
unsigned dw_offset, LLVMValueRef mask[4], LLVMValueRef *total_count,
LLVMValueRef *total_bitcount) LLVMValueRef *prefix_sum)
{ {
LLVMBuilderRef builder = ctx->ac.builder; LLVMBuilderRef builder = ctx->ac.builder;
LLVMValueRef ptr64 = LLVMBuildPointerCast( LLVMValueRef i8vec4_lane = ac_build_alloca_undef(&ctx->ac, ctx->ac.i32, "");
builder, lds_ptr, LLVMPointerType(LLVMArrayType(ctx->ac.i64, 2), AC_ADDR_SPACE_LDS), ""); unsigned num_i8vec4 = DIV_ROUND_UP(max_waves, 4);
LLVMValueRef tmp[2];
for (unsigned i = 0; i < 2; i++) /* If all threads loaded the vertex counts, it would cause many LDS bank conflicts
tmp[i] = ac_build_alloca_undef(&ctx->ac, ctx->ac.i64, "");
/* If all threads loaded the bitmasks, it would cause many LDS bank conflicts
* and the performance could decrease up to WaveSize times (32x or 64x). * and the performance could decrease up to WaveSize times (32x or 64x).
* *
* Therefore, only load the bitmasks in thread 0 and other threads will get them * Therefore, only load the i-th tuple of vertex counts in the i-th thread. Other threads will
* through readlane. * get them through readlane. 4 8-bit vertex counts are loaded per thread.
*/ */
ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntEQ, tid, ctx->ac.i32_0, ""), 17771); ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntULT, tid,
for (unsigned i = 0; i < 2; i++) { LLVMConstInt(ctx->ac.i32, num_i8vec4, 0), ""), 17771);
LLVMValueRef index = LLVMConstInt(ctx->ac.i32, dw_offset / 2 + i, 0); LLVMBuildStore(builder, LLVMBuildLoad(builder, ac_build_gep0(&ctx->ac, lds, tid), ""), i8vec4_lane);
LLVMValueRef val = LLVMBuildLoad(builder, ac_build_gep0(&ctx->ac, ptr64, index), "");
LLVMBuildStore(builder, val, tmp[i]);
}
ac_build_endif(&ctx->ac, 17771); ac_build_endif(&ctx->ac, 17771);
*total_bitcount = ctx->ac.i32_0; /* Compute the number of ES waves. */
LLVMValueRef num_waves = get_tgsize(ctx);
for (unsigned i = 0; i < 2; i++) { /* Compute a byte mask where each byte is either 0 or 0xff depending on whether the wave
tmp[i] = LLVMBuildLoad(builder, tmp[i], ""); * exists. We need the mask to clear uninitialized bytes in LDS and to compute the prefix sum.
mask[i] = ac_build_readlane_no_opt_barrier(&ctx->ac, tmp[i], NULL); *
* 8 waves: valid_mask = ~0ull >> (64 - num_waves * 8)
* 4 waves: valid_mask = ~0 >> (32 - num_waves * 8)
*/
LLVMValueRef num_waves8 = LLVMBuildShl(builder, num_waves, LLVMConstInt(ctx->ac.i32, 3, 0), "");
LLVMValueRef valid_mask;
*total_bitcount = LLVMBuildAdd(builder, *total_bitcount, if (max_waves > 4) {
ac_build_bit_count(&ctx->ac, mask[i]), ""); LLVMValueRef num_waves8_rev = LLVMBuildSub(builder, LLVMConstInt(ctx->ac.i32, 64, 0),
num_waves8, "");
valid_mask = LLVMBuildLShr(builder, LLVMConstInt(ctx->ac.i64, ~0ull, 0),
LLVMBuildZExt(builder, num_waves8_rev, ctx->ac.i64, ""), "");
} else {
LLVMValueRef num_waves8_rev = LLVMBuildSub(builder, LLVMConstInt(ctx->ac.i32, 32, 0),
num_waves8, "");
valid_mask = LLVMBuildLShr(builder, LLVMConstInt(ctx->ac.i32, ~0, 0), num_waves8_rev, "");
} }
/* Compute a byte mask where bytes below wave_id are 0xff, else they are 0.
*
* prefix_mask = ~(~0 << (wave_id * 8))
*/
LLVMTypeRef type = max_waves > 4 ? ctx->ac.i64 : ctx->ac.i32;
LLVMValueRef wave_id8 = LLVMBuildShl(builder, get_wave_id_in_tg(ctx),
LLVMConstInt(ctx->ac.i32, 3, 0), "");
LLVMValueRef prefix_mask =
LLVMBuildNot(builder, LLVMBuildShl(builder, LLVMConstInt(type, ~0ull, 0),
LLVMBuildZExt(builder, wave_id8, type, ""), ""), "");
/* Compute the total vertex count and the vertex count of previous waves (prefix). */
*total_count = ctx->ac.i32_0;
*prefix_sum = ctx->ac.i32_0;
for (unsigned i = 0; i < num_i8vec4; i++) {
LLVMValueRef i8vec4;
i8vec4 = ac_build_readlane_no_opt_barrier(&ctx->ac, LLVMBuildLoad(builder, i8vec4_lane, ""),
LLVMConstInt(ctx->ac.i32, i, 0));
/* Inactive waves have uninitialized vertex counts. Set them to 0 using this. */
i8vec4 = LLVMBuildAnd(builder, i8vec4,
ac_unpack_param(&ctx->ac, valid_mask, 32 * i, 32), "");
/* Compute the sum of all i8vec4 components and add it to the result. */
*total_count = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.sad.u8", ctx->ac.i32,
(LLVMValueRef[]){i8vec4, ctx->ac.i32_0, *total_count},
3, AC_FUNC_ATTR_READNONE);
ac_set_range_metadata(&ctx->ac, *total_count, 0, 64*4 + 1); /* the result is at most 64*4 */
/* Compute the sum of the vertex counts of all previous waves. */
i8vec4 = LLVMBuildAnd(builder, i8vec4,
ac_unpack_param(&ctx->ac, prefix_mask, 32 * i, 32), "");
*prefix_sum = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.sad.u8", ctx->ac.i32,
(LLVMValueRef[]){i8vec4, ctx->ac.i32_0, *prefix_sum},
3, AC_FUNC_ATTR_READNONE);
ac_set_range_metadata(&ctx->ac, *prefix_sum, 0, 64*4 + 1); /* the result is at most 64*4 */
}
*total_count = ac_build_readlane_no_opt_barrier(&ctx->ac, *total_count, NULL);
} }
/** /**
@@ -746,14 +796,9 @@ void gfx10_emit_ngg_culling_epilogue(struct ac_shader_abi *abi, unsigned max_out
struct si_shader_selector *sel = shader->selector; struct si_shader_selector *sel = shader->selector;
struct si_shader_info *info = &sel->info; struct si_shader_info *info = &sel->info;
LLVMBuilderRef builder = ctx->ac.builder; LLVMBuilderRef builder = ctx->ac.builder;
unsigned max_waves = ctx->ac.wave_size == 64 ? 2 : 4; unsigned subgroup_size = 128;
LLVMValueRef ngg_scratch = ctx->gs_ngg_scratch; unsigned max_waves = ctx->ac.wave_size == 64 ? DIV_ROUND_UP(subgroup_size, 64) :
DIV_ROUND_UP(subgroup_size, 32);
if (ctx->ac.wave_size == 64) {
ngg_scratch = LLVMBuildPointerCast(builder, ngg_scratch,
LLVMPointerType(LLVMArrayType(ctx->ac.i64, max_waves),
AC_ADDR_SPACE_LDS), "");
}
assert(shader->key.opt.ngg_culling); assert(shader->key.opt.ngg_culling);
assert(shader->key.as_ngg); assert(shader->key.as_ngg);
@@ -805,48 +850,32 @@ void gfx10_emit_ngg_culling_epilogue(struct ac_shader_abi *abi, unsigned max_out
builder, ctx->ac.i32_0, builder, ctx->ac.i32_0,
ac_build_gep0(&ctx->ac, es_vtxptr, LLVMConstInt(ctx->ac.i32, lds_packed_data, 0))); ac_build_gep0(&ctx->ac, es_vtxptr, LLVMConstInt(ctx->ac.i32, lds_packed_data, 0)));
ac_build_endif(&ctx->ac, ctx->merged_wrap_if_label); ac_build_endif(&ctx->ac, ctx->merged_wrap_if_label);
ac_build_s_barrier(&ctx->ac);
LLVMValueRef tid = ac_get_thread_id(&ctx->ac); LLVMValueRef tid = ac_get_thread_id(&ctx->ac);
/* Initialize all but the first element of ngg_scratch to 0, because we may have less
* than the maximum number of waves, but we always read all values. This is where
* the thread bitmasks of unculled threads will be stored.
*
* ngg_scratch layout: iN_wavemask esmask[0..n]
*/
ac_build_ifcc(&ctx->ac,
LLVMBuildICmp(builder, LLVMIntULT, get_thread_id_in_tg(ctx),
LLVMConstInt(ctx->ac.i32, max_waves - 1, 0), ""),
16101);
{
LLVMValueRef index = LLVMBuildAdd(builder, tid, ctx->ac.i32_1, "");
LLVMBuildStore(builder, LLVMConstInt(ctx->ac.iN_wavemask, 0, 0),
ac_build_gep0(&ctx->ac, ngg_scratch, index));
}
ac_build_endif(&ctx->ac, 16101);
ac_build_s_barrier(&ctx->ac);
/* The hardware requires that there are no holes between unculled vertices, /* The hardware requires that there are no holes between unculled vertices,
* which means we have to pack ES threads, i.e. reduce the ES thread count * which means we have to pack ES threads, i.e. reduce the ES thread count
* and move ES input VGPRs to lower threads. The upside is that varyings * and move ES input VGPRs to lower threads. The upside is that varyings
* are only fetched and computed for unculled vertices. * are only fetched and computed for unculled vertices.
* *
* Vertex compaction in GS threads: * Vertex compaction:
* *
* Part 1: Compute the surviving vertex mask in GS threads: * Part 1: Store the surviving vertex count for each wave in LDS.
* - Compute 4 32-bit surviving vertex masks in LDS. (max 4 waves) * - The GS culling code notifies ES threads which vertices were accepted.
* - In GS, notify ES threads whether the vertex survived.
* - Barrier * - Barrier
* - ES threads will create the mask and store it in LDS. * - ES threads will compute the vertex count and store it in LDS.
* - Barrier * - Barrier
* - Each GS thread loads the vertex masks from LDS. * - Each wave loads the vertex counts from LDS.
* *
* Part 2: Compact ES threads in GS threads: * Part 2: Compact ES threads:
* - Compute the prefix sum for all 3 vertices from the masks. These are the new * - Compute the prefix sum for each surviving vertex. This is the new thread ID
* thread IDs for each vertex within the primitive. * of the vertex.
* - Write input VGPRs and vertex positions into the LDS address of the new thread ID. * - Write input VGPRs and vertex positions for each surviving vertex into the LDS
* address of the new thread ID.
* - Now kill all waves that have inactive threads.
* - Barrier
* - Update vertex indices and null flag in the GS input VGPRs. * - Update vertex indices and null flag in the GS input VGPRs.
* - Barrier
* *
* Part 3: Update input GPRs
* - For all waves, update per-wave thread counts in input SGPRs. * - For all waves, update per-wave thread counts in input SGPRs.
@@ -972,33 +1001,38 @@ void gfx10_emit_ngg_culling_epilogue(struct ac_shader_abi *abi, unsigned max_out
gs_accepted = LLVMBuildLoad(builder, gs_accepted, ""); gs_accepted = LLVMBuildLoad(builder, gs_accepted, "");
LLVMValueRef es_accepted = ac_build_alloca(&ctx->ac, ctx->ac.i1, ""); LLVMValueRef vertex_accepted = ac_build_alloca(&ctx->ac, ctx->ac.i1, "");
LLVMValueRef vertex_mask = ac_build_alloca(&ctx->ac, ctx->ac.iN_wavemask, "");
/* Convert the per-vertex flag to a thread bitmask in ES threads and store it in LDS. */ /* Convert the per-vertex accept flag to a vertex thread mask, store it in registers. */
ac_build_ifcc(&ctx->ac, si_is_es_thread(ctx), 16007); ac_build_ifcc(&ctx->ac, si_is_es_thread(ctx), 16007);
{ {
LLVMValueRef es_accepted_flag = LLVMValueRef accepted =
LLVMBuildLoad(builder, si_build_gep_i8(ctx, es_vtxptr, lds_byte0_accept_flag), ""); LLVMBuildLoad(builder, si_build_gep_i8(ctx, es_vtxptr, lds_byte0_accept_flag), "");
accepted = LLVMBuildICmp(builder, LLVMIntNE, accepted, ctx->ac.i8_0, "");
LLVMValueRef mask = ac_get_i1_sgpr_mask(&ctx->ac, accepted);
LLVMValueRef es_accepted_bool = LLVMBuildStore(builder, accepted, vertex_accepted);
LLVMBuildICmp(builder, LLVMIntNE, es_accepted_flag, ctx->ac.i8_0, ""); LLVMBuildStore(builder, mask, vertex_mask);
LLVMValueRef es_mask = ac_get_i1_sgpr_mask(&ctx->ac, es_accepted_bool);
LLVMBuildStore(builder, es_accepted_bool, es_accepted);
ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntEQ, tid, ctx->ac.i32_0, ""), 16008);
{
LLVMBuildStore(builder, es_mask,
ac_build_gep0(&ctx->ac, ngg_scratch, get_wave_id_in_tg(ctx)));
}
ac_build_endif(&ctx->ac, 16008);
} }
ac_build_endif(&ctx->ac, 16007); ac_build_endif(&ctx->ac, 16007);
/* Store the per-wave vertex count to LDS. Non-ES waves store 0. */
vertex_mask = LLVMBuildLoad(builder, vertex_mask, "");
ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntEQ, tid, ctx->ac.i32_0, ""), 16008);
{
LLVMValueRef vertex_count = ac_build_bit_count(&ctx->ac, vertex_mask);
LLVMBuildStore(builder, LLVMBuildTrunc(builder, vertex_count, ctx->ac.i8, ""),
si_build_gep_i8_var(ctx, ctx->gs_ngg_scratch, get_wave_id_in_tg(ctx)));
}
ac_build_endif(&ctx->ac, 16008);
ac_build_s_barrier(&ctx->ac); ac_build_s_barrier(&ctx->ac);
/* Load the vertex masks and compute the new ES thread count. */ /* Load the vertex masks and compute the new ES thread count. */
LLVMValueRef es_mask[2], new_num_es_threads, kill_wave; LLVMValueRef new_num_es_threads, prefix_sum, kill_wave;
load_bitmasks_2x64(ctx, ngg_scratch, tid, 0, es_mask, &new_num_es_threads); load_vertex_counts(ctx, ctx->gs_ngg_scratch, max_waves, tid, &new_num_es_threads,
&prefix_sum);
bool uses_instance_id = ctx->stage == MESA_SHADER_VERTEX && bool uses_instance_id = ctx->stage == MESA_SHADER_VERTEX &&
(sel->info.uses_instanceid || (sel->info.uses_instanceid ||
@@ -1012,10 +1046,15 @@ void gfx10_emit_ngg_culling_epilogue(struct ac_shader_abi *abi, unsigned max_out
* of the new thread ID. It will be used to load input VGPRs by compacted * of the new thread ID. It will be used to load input VGPRs by compacted
* threads. * threads.
*/ */
ac_build_ifcc(&ctx->ac, LLVMBuildLoad(builder, es_accepted, ""), 16009); vertex_accepted = LLVMBuildLoad(builder, vertex_accepted, "");
ac_build_ifcc(&ctx->ac, vertex_accepted, 16009);
{ {
LLVMValueRef old_id = get_thread_id_in_tg(ctx); /* Add the number of bits set in vertex_mask up to the current thread ID - 1
LLVMValueRef new_id = ac_prefix_bitcount_2x64(&ctx->ac, es_mask, old_id); * to get the prefix sum.
*/
prefix_sum = LLVMBuildAdd(builder, prefix_sum, ac_build_mbcnt(&ctx->ac, vertex_mask), "");
LLVMValueRef new_id = prefix_sum;
LLVMValueRef new_vtx = ngg_nogs_vertex_ptr(ctx, new_id); LLVMValueRef new_vtx = ngg_nogs_vertex_ptr(ctx, new_id);
LLVMBuildStore(builder, LLVMBuildTrunc(builder, new_id, ctx->ac.i8, ""), LLVMBuildStore(builder, LLVMBuildTrunc(builder, new_id, ctx->ac.i8, ""),