diff --git a/src/amd/llvm/ac_llvm_build.c b/src/amd/llvm/ac_llvm_build.c index fefe383806b..43bafb81443 100644 --- a/src/amd/llvm/ac_llvm_build.c +++ b/src/amd/llvm/ac_llvm_build.c @@ -3293,17 +3293,21 @@ LLVMValueRef ac_trim_vector(struct ac_llvm_context *ctx, LLVMValueRef value, uns return LLVMBuildShuffleVector(ctx->builder, value, value, swizzle, ""); } +/* If param is i64 and bitwidth <= 32, the return value will be i32. */ LLVMValueRef ac_unpack_param(struct ac_llvm_context *ctx, LLVMValueRef param, unsigned rshift, unsigned bitwidth) { LLVMValueRef value = param; if (rshift) - value = LLVMBuildLShr(ctx->builder, value, LLVMConstInt(ctx->i32, rshift, false), ""); + value = LLVMBuildLShr(ctx->builder, value, LLVMConstInt(LLVMTypeOf(param), rshift, false), ""); if (rshift + bitwidth < 32) { - unsigned mask = (1 << bitwidth) - 1; - value = LLVMBuildAnd(ctx->builder, value, LLVMConstInt(ctx->i32, mask, false), ""); + uint64_t mask = (1ull << bitwidth) - 1; + value = LLVMBuildAnd(ctx->builder, value, LLVMConstInt(LLVMTypeOf(param), mask, false), ""); } + + if (bitwidth <= 32 && LLVMTypeOf(param) == ctx->i64) + value = LLVMBuildTrunc(ctx->builder, value, ctx->i32, ""); return value; } @@ -4723,64 +4727,6 @@ void ac_build_s_endpgm(struct ac_llvm_context *ctx) LLVMBuildCall(ctx->builder, code, NULL, 0, ""); } -LLVMValueRef ac_prefix_bitcount(struct ac_llvm_context *ctx, LLVMValueRef mask, LLVMValueRef index) -{ - LLVMBuilderRef builder = ctx->builder; - LLVMTypeRef type = LLVMTypeOf(mask); - - LLVMValueRef bit = - LLVMBuildShl(builder, LLVMConstInt(type, 1, 0), LLVMBuildZExt(builder, index, type, ""), ""); - LLVMValueRef prefix_bits = LLVMBuildSub(builder, bit, LLVMConstInt(type, 1, 0), ""); - LLVMValueRef prefix_mask = LLVMBuildAnd(builder, mask, prefix_bits, ""); - return ac_build_bit_count(ctx, prefix_mask); -} - -/* Compute the prefix sum of the "mask" bit array with 128 elements (bits). */ -LLVMValueRef ac_prefix_bitcount_2x64(struct ac_llvm_context *ctx, LLVMValueRef mask[2], - LLVMValueRef index) -{ - LLVMBuilderRef builder = ctx->builder; -#if 0 - /* Reference version using i128. */ - LLVMValueRef input_mask = - LLVMBuildBitCast(builder, ac_build_gather_values(ctx, mask, 2), ctx->i128, ""); - - return ac_prefix_bitcount(ctx, input_mask, index); -#else - /* Optimized version using 2 64-bit masks. */ - LLVMValueRef is_hi, is_0, c64, c128, all_bits; - LLVMValueRef prefix_mask[2], shift[2], mask_bcnt0, prefix_bcnt[2]; - - /* Compute the 128-bit prefix mask. */ - c64 = LLVMConstInt(ctx->i32, 64, 0); - c128 = LLVMConstInt(ctx->i32, 128, 0); - all_bits = LLVMConstInt(ctx->i64, UINT64_MAX, 0); - /* The first index that can have non-zero high bits in the prefix mask is 65. */ - is_hi = LLVMBuildICmp(builder, LLVMIntUGT, index, c64, ""); - is_0 = LLVMBuildICmp(builder, LLVMIntEQ, index, ctx->i32_0, ""); - mask_bcnt0 = ac_build_bit_count(ctx, mask[0]); - - for (unsigned i = 0; i < 2; i++) { - shift[i] = LLVMBuildSub(builder, i ? c128 : c64, index, ""); - /* For i==0, index==0, the right shift by 64 doesn't give the desired result, - * so we handle it by the is_0 select. - * For i==1, index==64, same story, so we handle it by the last is_hi select. - * For i==0, index==64, we shift by 0, which is what we want. - */ - prefix_mask[i] = - LLVMBuildLShr(builder, all_bits, LLVMBuildZExt(builder, shift[i], ctx->i64, ""), ""); - prefix_mask[i] = LLVMBuildAnd(builder, mask[i], prefix_mask[i], ""); - prefix_bcnt[i] = ac_build_bit_count(ctx, prefix_mask[i]); - } - - prefix_bcnt[0] = LLVMBuildSelect(builder, is_0, ctx->i32_0, prefix_bcnt[0], ""); - prefix_bcnt[0] = LLVMBuildSelect(builder, is_hi, mask_bcnt0, prefix_bcnt[0], ""); - prefix_bcnt[1] = LLVMBuildSelect(builder, is_hi, prefix_bcnt[1], ctx->i32_0, ""); - - return LLVMBuildAdd(builder, prefix_bcnt[0], prefix_bcnt[1], ""); -#endif -} - /** * Convert triangle strip indices to triangle indices. This is used to decompose * triangle strips into triangles. diff --git a/src/amd/llvm/ac_llvm_build.h b/src/amd/llvm/ac_llvm_build.h index de83d583cd1..41ee1033660 100644 --- a/src/amd/llvm/ac_llvm_build.h +++ b/src/amd/llvm/ac_llvm_build.h @@ -607,9 +607,6 @@ LLVMValueRef ac_build_main(const struct ac_shader_args *args, struct ac_llvm_con LLVMTypeRef ret_type, LLVMModuleRef module); void ac_build_s_endpgm(struct ac_llvm_context *ctx); -LLVMValueRef ac_prefix_bitcount(struct ac_llvm_context *ctx, LLVMValueRef mask, LLVMValueRef index); -LLVMValueRef ac_prefix_bitcount_2x64(struct ac_llvm_context *ctx, LLVMValueRef mask[2], - LLVMValueRef index); void ac_build_triangle_strip_indices_to_triangle(struct ac_llvm_context *ctx, LLVMValueRef is_odd, LLVMValueRef flatshade_first, LLVMValueRef index[3]); diff --git a/src/gallium/drivers/radeonsi/gfx10_shader_ngg.c b/src/gallium/drivers/radeonsi/gfx10_shader_ngg.c index ac8bdb58b3a..b1611704e57 100644 --- a/src/gallium/drivers/radeonsi/gfx10_shader_ngg.c +++ b/src/gallium/drivers/radeonsi/gfx10_shader_ngg.c @@ -577,15 +577,20 @@ enum lds_tes_patch_id, /* optional */ }; +static LLVMValueRef si_build_gep_i8_var(struct si_shader_context *ctx, LLVMValueRef ptr, + LLVMValueRef index) +{ + LLVMTypeRef pi8 = LLVMPointerType(ctx->ac.i8, AC_ADDR_SPACE_LDS); + + return LLVMBuildGEP(ctx->ac.builder, LLVMBuildPointerCast(ctx->ac.builder, ptr, pi8, ""), &index, + 1, ""); +} + static LLVMValueRef si_build_gep_i8(struct si_shader_context *ctx, LLVMValueRef ptr, unsigned byte_index) { assert(byte_index < 4); - LLVMTypeRef pi8 = LLVMPointerType(ctx->ac.i8, AC_ADDR_SPACE_LDS); - LLVMValueRef index = LLVMConstInt(ctx->ac.i32, byte_index, 0); - - return LLVMBuildGEP(ctx->ac.builder, LLVMBuildPointerCast(ctx->ac.builder, ptr, pi8, ""), &index, - 1, ""); + return si_build_gep_i8_var(ctx, ptr, LLVMConstInt(ctx->ac.i32, byte_index, 0)); } static unsigned ngg_nogs_vertex_size(struct si_shader *shader) @@ -653,42 +658,87 @@ static LLVMValueRef si_insert_input_v4i32(struct si_shader_context *ctx, LLVMVal return ret; } -static void load_bitmasks_2x64(struct si_shader_context *ctx, LLVMValueRef lds_ptr, - LLVMValueRef tid, - unsigned dw_offset, LLVMValueRef mask[4], - LLVMValueRef *total_bitcount) +static void load_vertex_counts(struct si_shader_context *ctx, LLVMValueRef lds, + unsigned max_waves, LLVMValueRef tid, + LLVMValueRef *total_count, + LLVMValueRef *prefix_sum) { LLVMBuilderRef builder = ctx->ac.builder; - LLVMValueRef ptr64 = LLVMBuildPointerCast( - builder, lds_ptr, LLVMPointerType(LLVMArrayType(ctx->ac.i64, 2), AC_ADDR_SPACE_LDS), ""); - LLVMValueRef tmp[2]; + LLVMValueRef i8vec4_lane = ac_build_alloca_undef(&ctx->ac, ctx->ac.i32, ""); + unsigned num_i8vec4 = DIV_ROUND_UP(max_waves, 4); - for (unsigned i = 0; i < 2; i++) - tmp[i] = ac_build_alloca_undef(&ctx->ac, ctx->ac.i64, ""); - - /* If all threads loaded the bitmasks, it would cause many LDS bank conflicts + /* If all threads loaded the vertex counts, it would cause many LDS bank conflicts * and the performance could decrease up to WaveSize times (32x or 64x). * - * Therefore, only load the bitmasks in thread 0 and other threads will get them - * through readlane. + * Therefore, only load the i-th tuple of vertex counts in the i-th thread. Other threads will + * get them through readlane. 4 8-bit vertex counts are loaded per thread. */ - ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntEQ, tid, ctx->ac.i32_0, ""), 17771); - for (unsigned i = 0; i < 2; i++) { - LLVMValueRef index = LLVMConstInt(ctx->ac.i32, dw_offset / 2 + i, 0); - LLVMValueRef val = LLVMBuildLoad(builder, ac_build_gep0(&ctx->ac, ptr64, index), ""); - LLVMBuildStore(builder, val, tmp[i]); - } + ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntULT, tid, + LLVMConstInt(ctx->ac.i32, num_i8vec4, 0), ""), 17771); + LLVMBuildStore(builder, LLVMBuildLoad(builder, ac_build_gep0(&ctx->ac, lds, tid), ""), i8vec4_lane); ac_build_endif(&ctx->ac, 17771); - *total_bitcount = ctx->ac.i32_0; + /* Compute the number of ES waves. */ + LLVMValueRef num_waves = get_tgsize(ctx); - for (unsigned i = 0; i < 2; i++) { - tmp[i] = LLVMBuildLoad(builder, tmp[i], ""); - mask[i] = ac_build_readlane_no_opt_barrier(&ctx->ac, tmp[i], NULL); + /* Compute a byte mask where each byte is either 0 or 0xff depending on whether the wave + * exists. We need the mask to clear uninitialized bytes in LDS and to compute the prefix sum. + * + * 8 waves: valid_mask = ~0ull >> (64 - num_waves * 8) + * 4 waves: valid_mask = ~0 >> (32 - num_waves * 8) + */ + LLVMValueRef num_waves8 = LLVMBuildShl(builder, num_waves, LLVMConstInt(ctx->ac.i32, 3, 0), ""); + LLVMValueRef valid_mask; - *total_bitcount = LLVMBuildAdd(builder, *total_bitcount, - ac_build_bit_count(&ctx->ac, mask[i]), ""); + if (max_waves > 4) { + LLVMValueRef num_waves8_rev = LLVMBuildSub(builder, LLVMConstInt(ctx->ac.i32, 64, 0), + num_waves8, ""); + valid_mask = LLVMBuildLShr(builder, LLVMConstInt(ctx->ac.i64, ~0ull, 0), + LLVMBuildZExt(builder, num_waves8_rev, ctx->ac.i64, ""), ""); + } else { + LLVMValueRef num_waves8_rev = LLVMBuildSub(builder, LLVMConstInt(ctx->ac.i32, 32, 0), + num_waves8, ""); + valid_mask = LLVMBuildLShr(builder, LLVMConstInt(ctx->ac.i32, ~0, 0), num_waves8_rev, ""); } + + /* Compute a byte mask where bytes below wave_id are 0xff, else they are 0. + * + * prefix_mask = ~(~0 << (wave_id * 8)) + */ + LLVMTypeRef type = max_waves > 4 ? ctx->ac.i64 : ctx->ac.i32; + LLVMValueRef wave_id8 = LLVMBuildShl(builder, get_wave_id_in_tg(ctx), + LLVMConstInt(ctx->ac.i32, 3, 0), ""); + LLVMValueRef prefix_mask = + LLVMBuildNot(builder, LLVMBuildShl(builder, LLVMConstInt(type, ~0ull, 0), + LLVMBuildZExt(builder, wave_id8, type, ""), ""), ""); + + /* Compute the total vertex count and the vertex count of previous waves (prefix). */ + *total_count = ctx->ac.i32_0; + *prefix_sum = ctx->ac.i32_0; + + for (unsigned i = 0; i < num_i8vec4; i++) { + LLVMValueRef i8vec4; + + i8vec4 = ac_build_readlane_no_opt_barrier(&ctx->ac, LLVMBuildLoad(builder, i8vec4_lane, ""), + LLVMConstInt(ctx->ac.i32, i, 0)); + /* Inactive waves have uninitialized vertex counts. Set them to 0 using this. */ + i8vec4 = LLVMBuildAnd(builder, i8vec4, + ac_unpack_param(&ctx->ac, valid_mask, 32 * i, 32), ""); + /* Compute the sum of all i8vec4 components and add it to the result. */ + *total_count = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.sad.u8", ctx->ac.i32, + (LLVMValueRef[]){i8vec4, ctx->ac.i32_0, *total_count}, + 3, AC_FUNC_ATTR_READNONE); + ac_set_range_metadata(&ctx->ac, *total_count, 0, 64*4 + 1); /* the result is at most 64*4 */ + + /* Compute the sum of the vertex counts of all previous waves. */ + i8vec4 = LLVMBuildAnd(builder, i8vec4, + ac_unpack_param(&ctx->ac, prefix_mask, 32 * i, 32), ""); + *prefix_sum = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.sad.u8", ctx->ac.i32, + (LLVMValueRef[]){i8vec4, ctx->ac.i32_0, *prefix_sum}, + 3, AC_FUNC_ATTR_READNONE); + ac_set_range_metadata(&ctx->ac, *prefix_sum, 0, 64*4 + 1); /* the result is at most 64*4 */ + } + *total_count = ac_build_readlane_no_opt_barrier(&ctx->ac, *total_count, NULL); } /** @@ -746,14 +796,9 @@ void gfx10_emit_ngg_culling_epilogue(struct ac_shader_abi *abi, unsigned max_out struct si_shader_selector *sel = shader->selector; struct si_shader_info *info = &sel->info; LLVMBuilderRef builder = ctx->ac.builder; - unsigned max_waves = ctx->ac.wave_size == 64 ? 2 : 4; - LLVMValueRef ngg_scratch = ctx->gs_ngg_scratch; - - if (ctx->ac.wave_size == 64) { - ngg_scratch = LLVMBuildPointerCast(builder, ngg_scratch, - LLVMPointerType(LLVMArrayType(ctx->ac.i64, max_waves), - AC_ADDR_SPACE_LDS), ""); - } + unsigned subgroup_size = 128; + unsigned max_waves = ctx->ac.wave_size == 64 ? DIV_ROUND_UP(subgroup_size, 64) : + DIV_ROUND_UP(subgroup_size, 32); assert(shader->key.opt.ngg_culling); assert(shader->key.as_ngg); @@ -805,48 +850,32 @@ void gfx10_emit_ngg_culling_epilogue(struct ac_shader_abi *abi, unsigned max_out builder, ctx->ac.i32_0, ac_build_gep0(&ctx->ac, es_vtxptr, LLVMConstInt(ctx->ac.i32, lds_packed_data, 0))); ac_build_endif(&ctx->ac, ctx->merged_wrap_if_label); + ac_build_s_barrier(&ctx->ac); LLVMValueRef tid = ac_get_thread_id(&ctx->ac); - /* Initialize all but the first element of ngg_scratch to 0, because we may have less - * than the maximum number of waves, but we always read all values. This is where - * the thread bitmasks of unculled threads will be stored. - * - * ngg_scratch layout: iN_wavemask esmask[0..n] - */ - ac_build_ifcc(&ctx->ac, - LLVMBuildICmp(builder, LLVMIntULT, get_thread_id_in_tg(ctx), - LLVMConstInt(ctx->ac.i32, max_waves - 1, 0), ""), - 16101); - { - LLVMValueRef index = LLVMBuildAdd(builder, tid, ctx->ac.i32_1, ""); - LLVMBuildStore(builder, LLVMConstInt(ctx->ac.iN_wavemask, 0, 0), - ac_build_gep0(&ctx->ac, ngg_scratch, index)); - } - ac_build_endif(&ctx->ac, 16101); - ac_build_s_barrier(&ctx->ac); - /* The hardware requires that there are no holes between unculled vertices, * which means we have to pack ES threads, i.e. reduce the ES thread count * and move ES input VGPRs to lower threads. The upside is that varyings * are only fetched and computed for unculled vertices. * - * Vertex compaction in GS threads: + * Vertex compaction: * - * Part 1: Compute the surviving vertex mask in GS threads: - * - Compute 4 32-bit surviving vertex masks in LDS. (max 4 waves) - * - In GS, notify ES threads whether the vertex survived. + * Part 1: Store the surviving vertex count for each wave in LDS. + * - The GS culling code notifies ES threads which vertices were accepted. * - Barrier - * - ES threads will create the mask and store it in LDS. + * - ES threads will compute the vertex count and store it in LDS. * - Barrier - * - Each GS thread loads the vertex masks from LDS. + * - Each wave loads the vertex counts from LDS. * - * Part 2: Compact ES threads in GS threads: - * - Compute the prefix sum for all 3 vertices from the masks. These are the new - * thread IDs for each vertex within the primitive. - * - Write input VGPRs and vertex positions into the LDS address of the new thread ID. - * - Update vertex indices and null flag in the GS input VGPRs. + * Part 2: Compact ES threads: + * - Compute the prefix sum for each surviving vertex. This is the new thread ID + * of the vertex. + * - Write input VGPRs and vertex positions for each surviving vertex into the LDS + * address of the new thread ID. + * - Now kill all waves that have inactive threads. * - Barrier + * - Update vertex indices and null flag in the GS input VGPRs. * * Part 3: Update inputs GPRs * - For all waves, update per-wave thread counts in input SGPRs. @@ -972,33 +1001,38 @@ void gfx10_emit_ngg_culling_epilogue(struct ac_shader_abi *abi, unsigned max_out gs_accepted = LLVMBuildLoad(builder, gs_accepted, ""); - LLVMValueRef es_accepted = ac_build_alloca(&ctx->ac, ctx->ac.i1, ""); + LLVMValueRef vertex_accepted = ac_build_alloca(&ctx->ac, ctx->ac.i1, ""); + LLVMValueRef vertex_mask = ac_build_alloca(&ctx->ac, ctx->ac.iN_wavemask, ""); - /* Convert the per-vertex flag to a thread bitmask in ES threads and store it in LDS. */ + /* Convert the per-vertex accept flag to a vertex thread mask, store it in registers. */ ac_build_ifcc(&ctx->ac, si_is_es_thread(ctx), 16007); { - LLVMValueRef es_accepted_flag = + LLVMValueRef accepted = LLVMBuildLoad(builder, si_build_gep_i8(ctx, es_vtxptr, lds_byte0_accept_flag), ""); + accepted = LLVMBuildICmp(builder, LLVMIntNE, accepted, ctx->ac.i8_0, ""); + LLVMValueRef mask = ac_get_i1_sgpr_mask(&ctx->ac, accepted); - LLVMValueRef es_accepted_bool = - LLVMBuildICmp(builder, LLVMIntNE, es_accepted_flag, ctx->ac.i8_0, ""); - LLVMValueRef es_mask = ac_get_i1_sgpr_mask(&ctx->ac, es_accepted_bool); - - LLVMBuildStore(builder, es_accepted_bool, es_accepted); - - ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntEQ, tid, ctx->ac.i32_0, ""), 16008); - { - LLVMBuildStore(builder, es_mask, - ac_build_gep0(&ctx->ac, ngg_scratch, get_wave_id_in_tg(ctx))); - } - ac_build_endif(&ctx->ac, 16008); + LLVMBuildStore(builder, accepted, vertex_accepted); + LLVMBuildStore(builder, mask, vertex_mask); } ac_build_endif(&ctx->ac, 16007); + + /* Store the per-wave vertex count to LDS. Non-ES waves store 0. */ + vertex_mask = LLVMBuildLoad(builder, vertex_mask, ""); + ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntEQ, tid, ctx->ac.i32_0, ""), 16008); + { + LLVMValueRef vertex_count = ac_build_bit_count(&ctx->ac, vertex_mask); + LLVMBuildStore(builder, LLVMBuildTrunc(builder, vertex_count, ctx->ac.i8, ""), + si_build_gep_i8_var(ctx, ctx->gs_ngg_scratch, get_wave_id_in_tg(ctx))); + } + ac_build_endif(&ctx->ac, 16008); + ac_build_s_barrier(&ctx->ac); /* Load the vertex masks and compute the new ES thread count. */ - LLVMValueRef es_mask[2], new_num_es_threads, kill_wave; - load_bitmasks_2x64(ctx, ngg_scratch, tid, 0, es_mask, &new_num_es_threads); + LLVMValueRef new_num_es_threads, prefix_sum, kill_wave; + load_vertex_counts(ctx, ctx->gs_ngg_scratch, max_waves, tid, &new_num_es_threads, + &prefix_sum); bool uses_instance_id = ctx->stage == MESA_SHADER_VERTEX && (sel->info.uses_instanceid || @@ -1012,10 +1046,15 @@ void gfx10_emit_ngg_culling_epilogue(struct ac_shader_abi *abi, unsigned max_out * of the new thread ID. It will be used to load input VGPRs by compacted * threads. */ - ac_build_ifcc(&ctx->ac, LLVMBuildLoad(builder, es_accepted, ""), 16009); + vertex_accepted = LLVMBuildLoad(builder, vertex_accepted, ""); + ac_build_ifcc(&ctx->ac, vertex_accepted, 16009); { - LLVMValueRef old_id = get_thread_id_in_tg(ctx); - LLVMValueRef new_id = ac_prefix_bitcount_2x64(&ctx->ac, es_mask, old_id); + /* Add the number of bits set in vertex_mask up to the current thread ID - 1 + * to get the prefix sum. + */ + prefix_sum = LLVMBuildAdd(builder, prefix_sum, ac_build_mbcnt(&ctx->ac, vertex_mask), ""); + + LLVMValueRef new_id = prefix_sum; LLVMValueRef new_vtx = ngg_nogs_vertex_ptr(ctx, new_id); LLVMBuildStore(builder, LLVMBuildTrunc(builder, new_id, ctx->ac.i8, ""),