amd/common: scan/reduce across waves of a workgroup
Order-aware scan/reduce can trade off LDS traffic for external atomics memory
traffic in producer/consumer compute shaders.

Reviewed-by: Marek Olšák <marek.olsak@amd.com>
@@ -3112,24 +3112,44 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs,
         }
 }
 
-/* TODO: add inclusive and exclusive scan functions for SI chip class. */
+/**
+ * \param maxprefix specifies that the result only needs to be correct for a
+ * prefix of this many threads
+ *
+ * TODO: add inclusive and exclusive scan functions for SI chip class.
+ */
 static LLVMValueRef
-ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity)
+ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity,
+              unsigned maxprefix)
 {
         LLVMValueRef result, tmp;
         result = src;
+        if (maxprefix <= 1)
+                return result;
         tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(1), 0xf, 0xf, false);
         result = ac_build_alu_op(ctx, result, tmp, op);
+        if (maxprefix <= 2)
+                return result;
         tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(2), 0xf, 0xf, false);
         result = ac_build_alu_op(ctx, result, tmp, op);
+        if (maxprefix <= 3)
+                return result;
         tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(3), 0xf, 0xf, false);
         result = ac_build_alu_op(ctx, result, tmp, op);
+        if (maxprefix <= 4)
+                return result;
         tmp = ac_build_dpp(ctx, identity, result, dpp_row_sr(4), 0xf, 0xe, false);
         result = ac_build_alu_op(ctx, result, tmp, op);
+        if (maxprefix <= 8)
+                return result;
         tmp = ac_build_dpp(ctx, identity, result, dpp_row_sr(8), 0xf, 0xc, false);
         result = ac_build_alu_op(ctx, result, tmp, op);
+        if (maxprefix <= 16)
+                return result;
         tmp = ac_build_dpp(ctx, identity, result, dpp_row_bcast15, 0xa, 0xf, false);
         result = ac_build_alu_op(ctx, result, tmp, op);
+        if (maxprefix <= 32)
+                return result;
         tmp = ac_build_dpp(ctx, identity, result, dpp_row_bcast31, 0xc, 0xf, false);
         result = ac_build_alu_op(ctx, result, tmp, op);
         return result;
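The DPP shift/broadcast sequence in ac_build_scan is a log-step parallel prefix within a single 64-lane wave; the new maxprefix parameter lets callers stop the sequence early when only the first few lanes need a correct result (the wave-level scan added further below only ever needs maxwaves lanes). A rough scalar reference of the intended contract, with iadd standing in for the nir_op (illustration only, not part of the patch):

#include <stdint.h>

/* Scalar reference for ac_build_scan's contract (illustration only): an
 * inclusive prefix scan whose result only has to be correct for the first
 * maxprefix lanes of a 64-lane wave; iadd stands in for the nir_op. */
static void reference_scan_iadd(const uint32_t src[64], uint32_t dst[64], unsigned maxprefix)
{
        uint32_t acc = 0;                      /* identity for iadd */
        for (unsigned lane = 0; lane < maxprefix; ++lane) {
                acc += src[lane];              /* combine lanes 0..lane */
                dst[lane] = acc;               /* inclusive result for this lane */
        }
        /* lanes >= maxprefix may hold anything */
}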
@@ -3144,7 +3164,7 @@ ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op
                 get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
         result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity),
                                   LLVMTypeOf(identity), "");
-        result = ac_build_scan(ctx, op, result, identity);
+        result = ac_build_scan(ctx, op, result, identity, 64);
 
         return ac_build_wwm(ctx, result);
 }
@@ -3159,7 +3179,7 @@ ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op
         result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity),
                                   LLVMTypeOf(identity), "");
         result = ac_build_dpp(ctx, identity, result, dpp_wf_sr1, 0xf, 0xf, false);
-        result = ac_build_scan(ctx, op, result, identity);
+        result = ac_build_scan(ctx, op, result, identity, 64);
 
         return ac_build_wwm(ctx, result);
 }
@@ -3217,6 +3237,173 @@ ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsign
         }
 }
 
+/**
+ * "Top half" of a scan that reduces per-wave values across an entire
+ * workgroup.
+ *
+ * The source value must be present in the highest lane of the wave, and the
+ * highest lane must be live.
+ */
+void
+ac_build_wg_wavescan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
+{
+        if (ws->maxwaves <= 1)
+                return;
+
+        const LLVMValueRef i32_63 = LLVMConstInt(ctx->i32, 63, false);
+        LLVMBuilderRef builder = ctx->builder;
+        LLVMValueRef tid = ac_get_thread_id(ctx);
+        LLVMValueRef tmp;
+
+        tmp = LLVMBuildICmp(builder, LLVMIntEQ, tid, i32_63, "");
+        ac_build_ifcc(ctx, tmp, 1000);
+        LLVMBuildStore(builder, ws->src, LLVMBuildGEP(builder, ws->scratch, &ws->waveidx, 1, ""));
+        ac_build_endif(ctx, 1000);
+}
+
+/**
+ * "Bottom half" of a scan that reduces per-wave values across an entire
+ * workgroup.
+ *
+ * The caller must place a barrier between the top and bottom halves.
+ */
+void
+ac_build_wg_wavescan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
+{
+        const LLVMTypeRef type = LLVMTypeOf(ws->src);
+        const LLVMValueRef identity =
+                get_reduction_identity(ctx, ws->op, ac_get_type_size(type));
+
+        if (ws->maxwaves <= 1) {
+                ws->result_reduce = ws->src;
+                ws->result_inclusive = ws->src;
+                ws->result_exclusive = identity;
+                return;
+        }
+        assert(ws->maxwaves <= 32);
+
+        LLVMBuilderRef builder = ctx->builder;
+        LLVMValueRef tid = ac_get_thread_id(ctx);
+        LLVMBasicBlockRef bbs[2];
+        LLVMValueRef phivalues_scan[2];
+        LLVMValueRef tmp, tmp2;
+
+        bbs[0] = LLVMGetInsertBlock(builder);
+        phivalues_scan[0] = LLVMGetUndef(type);
+
+        if (ws->enable_reduce)
+                tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, ws->numwaves, "");
+        else if (ws->enable_inclusive)
+                tmp = LLVMBuildICmp(builder, LLVMIntULE, tid, ws->waveidx, "");
+        else
+                tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, ws->waveidx, "");
+        ac_build_ifcc(ctx, tmp, 1001);
+        {
+                tmp = LLVMBuildLoad(builder, LLVMBuildGEP(builder, ws->scratch, &tid, 1, ""), "");
+
+                ac_build_optimization_barrier(ctx, &tmp);
+
+                bbs[1] = LLVMGetInsertBlock(builder);
+                phivalues_scan[1] = ac_build_scan(ctx, ws->op, tmp, identity, ws->maxwaves);
+        }
+        ac_build_endif(ctx, 1001);
+
+        const LLVMValueRef scan = ac_build_phi(ctx, type, 2, phivalues_scan, bbs);
+
+        if (ws->enable_reduce) {
+                tmp = LLVMBuildSub(builder, ws->numwaves, ctx->i32_1, "");
+                ws->result_reduce = ac_build_readlane(ctx, scan, tmp);
+        }
+        if (ws->enable_inclusive)
+                ws->result_inclusive = ac_build_readlane(ctx, scan, ws->waveidx);
+        if (ws->enable_exclusive) {
+                tmp = LLVMBuildSub(builder, ws->waveidx, ctx->i32_1, "");
+                tmp = ac_build_readlane(ctx, scan, tmp);
+                tmp2 = LLVMBuildICmp(builder, LLVMIntEQ, ws->waveidx, ctx->i32_0, "");
+                ws->result_exclusive = LLVMBuildSelect(builder, tmp2, identity, tmp, "");
+        }
+}
+
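For intuition, a scalar model of what the bottom half derives from the per-wave values written to scratch by the top half (illustration only; vals[i] stands for wave i's entry and iadd is the example op):

#include <stdint.h>

/* Scalar model of ac_build_wg_wavescan_bottom's outputs (illustration only). */
static void wg_wavescan_model(const uint32_t *vals, unsigned numwaves, unsigned waveidx,
                              uint32_t *reduce, uint32_t *inclusive, uint32_t *exclusive)
{
        uint32_t sum = 0;                      /* identity for iadd */
        *exclusive = 0;                        /* identity when waveidx == 0 */
        for (unsigned i = 0; i < numwaves; ++i) {
                if (i == waveidx)
                        *exclusive = sum;      /* waves strictly before ours */
                sum += vals[i];
                if (i == waveidx)
                        *inclusive = sum;      /* waves up to and including ours */
        }
        *reduce = sum;                         /* read from lane numwaves-1 of the scan */
}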
+/**
+ * Inclusive scan of a per-wave value across an entire workgroup.
+ *
+ * This implies an s_barrier instruction.
+ *
+ * Unlike ac_build_inclusive_scan, the caller \em must ensure that all threads
+ * of the workgroup are live. (This requirement cannot easily be relaxed in a
+ * useful manner because of the barrier in the algorithm.)
+ */
+void
+ac_build_wg_wavescan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
+{
+        ac_build_wg_wavescan_top(ctx, ws);
+        ac_build_s_barrier(ctx);
+        ac_build_wg_wavescan_bottom(ctx, ws);
+}
+
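A minimal usage sketch of this helper, assuming the caller already has a per-wave value in lane 63, the wave index, the wave count, and an LDS scratch array; every name below other than the ac_* API and the struct fields is hypothetical:

/* Hypothetical helper (sketch only): order-aware running sum of one value per wave. */
static LLVMValueRef
sum_per_wave_values(struct ac_llvm_context *ctx, LLVMValueRef per_wave_value,
                    LLVMValueRef wave_id, LLVMValueRef wave_count,
                    LLVMValueRef lds_scratch, unsigned max_waves)
{
        struct ac_wg_scan ws = {0};
        ws.op = nir_op_iadd;             /* any op handled by ac_build_alu_op */
        ws.enable_inclusive = true;
        ws.src = per_wave_value;         /* must be valid and live in lane 63 of each wave */
        ws.waveidx = wave_id;            /* i32 index of this wave within the workgroup */
        ws.numwaves = wave_count;        /* only required when enable_reduce is set */
        ws.scratch = lds_scratch;        /* LDS array with at least max_waves entries */
        ws.maxwaves = max_waves;

        ac_build_wg_wavescan(ctx, &ws);  /* top half, s_barrier, bottom half */
        return ws.result_inclusive;      /* uniform within each wave */
}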
+/**
+ * "Top half" of a scan that reduces per-thread values across an entire
+ * workgroup.
+ *
+ * All lanes must be active when this code runs.
+ */
+void
+ac_build_wg_scan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
+{
+        if (ws->enable_exclusive) {
+                ws->extra = ac_build_exclusive_scan(ctx, ws->src, ws->op);
+                ws->src = ac_build_alu_op(ctx, ws->extra, ws->src, ws->op);
+        } else {
+                ws->src = ac_build_inclusive_scan(ctx, ws->src, ws->op);
+        }
+
+        bool enable_inclusive = ws->enable_inclusive;
+        bool enable_exclusive = ws->enable_exclusive;
+        ws->enable_inclusive = false;
+        ws->enable_exclusive = ws->enable_exclusive || enable_inclusive;
+        ac_build_wg_wavescan_top(ctx, ws);
+        ws->enable_inclusive = enable_inclusive;
+        ws->enable_exclusive = enable_exclusive;
+}
+
+/**
+ * "Bottom half" of a scan that reduces per-thread values across an entire
+ * workgroup.
+ *
+ * The caller must place a barrier between the top and bottom halves.
+ */
+void
+ac_build_wg_scan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
+{
+        bool enable_inclusive = ws->enable_inclusive;
+        bool enable_exclusive = ws->enable_exclusive;
+        ws->enable_inclusive = false;
+        ws->enable_exclusive = ws->enable_exclusive || enable_inclusive;
+        ac_build_wg_wavescan_bottom(ctx, ws);
+        ws->enable_inclusive = enable_inclusive;
+        ws->enable_exclusive = enable_exclusive;
+
+        /* ws->result_reduce is already the correct value */
+        if (ws->enable_inclusive)
+                ws->result_inclusive = ac_build_alu_op(ctx, ws->result_exclusive, ws->src, ws->op);
+        if (ws->enable_exclusive)
+                ws->result_exclusive = ac_build_alu_op(ctx, ws->result_exclusive, ws->extra, ws->op);
+}
+
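The bottom half assembles the per-thread results from two pieces: the wave-level exclusive result (everything contributed by earlier waves) and the in-wave scans from the top half (the in-wave inclusive scan left in ws->src, the in-wave exclusive scan in ws->extra). A scalar model of the same arithmetic, with iadd as the example op (illustration only):

#include <stdint.h>

/* Scalar model of ac_build_wg_scan's per-thread results (illustration only):
 * thread = wave * 64 + lane, iadd as the example op. */
static void wg_scan_model(const uint32_t *per_thread, unsigned wave, unsigned lane,
                          uint32_t *inclusive, uint32_t *exclusive)
{
        uint32_t wave_excl = 0, lane_excl = 0, lane_incl;

        for (unsigned i = 0; i < wave * 64; ++i)
                wave_excl += per_thread[i];                   /* all threads of earlier waves */
        for (unsigned i = 0; i < lane; ++i)
                lane_excl += per_thread[wave * 64 + i];       /* in-wave exclusive scan (ws->extra) */
        lane_incl = lane_excl + per_thread[wave * 64 + lane]; /* in-wave inclusive scan (ws->src) */

        *inclusive = wave_excl + lane_incl;
        *exclusive = wave_excl + lane_excl;
}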
+/**
+ * A scan that reduces per-thread values across an entire workgroup.
+ *
+ * The caller must ensure that all lanes are active when this code runs
+ * (WWM is insufficient!), because there is an implied barrier.
+ */
+void
+ac_build_wg_scan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
+{
+        ac_build_wg_scan_top(ctx, ws);
+        ac_build_s_barrier(ctx);
+        ac_build_wg_scan_bottom(ctx, ws);
+}
+
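The producer/consumer pattern from the commit message, sketched under assumptions (the helper name and all values passed in are hypothetical): each thread computes how many items it wants to append, the order-aware exclusive scan gives that thread's slot within the workgroup, and the reduce result lets a single lane perform one external atomic for the whole workgroup instead of one atomic per thread.

/* Hypothetical sketch: reserve per-thread output slots with one external atomic
 * per workgroup. The global-counter atomic and the broadcast of its result back
 * to the threads are not shown. */
static LLVMValueRef
reserve_workgroup_slots(struct ac_llvm_context *ctx, LLVMValueRef per_thread_count,
                        LLVMValueRef wave_id, LLVMValueRef wave_count,
                        LLVMValueRef lds_scratch, unsigned max_waves,
                        LLVMValueRef *group_total /* out: ws.result_reduce */)
{
        struct ac_wg_scan ws = {0};
        ws.op = nir_op_iadd;
        ws.enable_exclusive = true;      /* this thread's offset within the group */
        ws.enable_reduce = true;         /* total number of items for the group */
        ws.src = per_thread_count;       /* all lanes of the workgroup must be active */
        ws.waveidx = wave_id;
        ws.numwaves = wave_count;
        ws.scratch = lds_scratch;
        ws.maxwaves = max_waves;

        ac_build_wg_scan(ctx, &ws);      /* top half, s_barrier, bottom half */

        *group_total = ws.result_reduce; /* one lane adds this to a global counter */
        return ws.result_exclusive;      /* add the group's base offset to this */
}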
 LLVMValueRef
 ac_build_quad_swizzle(struct ac_llvm_context *ctx, LLVMValueRef src,
                       unsigned lane0, unsigned lane1, unsigned lane2, unsigned lane3)
@@ -524,6 +524,42 @@ ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op
 LLVMValueRef
 ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsigned cluster_size);
 
+/**
+ * Common arguments for a scan/reduce operation that accumulates per-wave
+ * values across an entire workgroup, while respecting the order of waves.
+ */
+struct ac_wg_scan {
+        bool enable_reduce;
+        bool enable_exclusive;
+        bool enable_inclusive;
+        nir_op op;
+        LLVMValueRef src; /* clobbered! */
+        LLVMValueRef result_reduce;
+        LLVMValueRef result_exclusive;
+        LLVMValueRef result_inclusive;
+        LLVMValueRef extra;
+        LLVMValueRef waveidx;
+        LLVMValueRef numwaves; /* only needed for "reduce" operations */
+
+        /* T addrspace(LDS) pointer to the same type as value, at least maxwaves entries */
+        LLVMValueRef scratch;
+        unsigned maxwaves;
+};
+
+void
+ac_build_wg_wavescan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
+void
+ac_build_wg_wavescan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
+void
+ac_build_wg_wavescan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
+
+void
+ac_build_wg_scan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
+void
+ac_build_wg_scan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
+void
+ac_build_wg_scan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws);
+
 LLVMValueRef
 ac_build_quad_swizzle(struct ac_llvm_context *ctx, LLVMValueRef src,
                       unsigned lane0, unsigned lane1, unsigned lane2, unsigned lane3);