amd: Add extra source to the mbcnt_amd NIR intrinsic.

The v_mbcnt instructions can take an extra source that they add to
the result. This is not exposed in SPIR-V but we now expose it in NIR.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Tony Wasserka <tony.wasserka@gmx.de>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/11072>
This commit is contained in:
Timur Kristóf
2021-06-09 11:00:22 +02:00
committed by Marge Bot
parent f6b2db298f
commit 1e49018ced
8 changed files with 30 additions and 9 deletions

View File

@@ -97,7 +97,7 @@ repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool,
if (max_num_waves == 1) {
wg_repack_result r = {
.num_repacked_invocations = surviving_invocations_in_current_wave,
.repacked_invocation_index = nir_build_mbcnt_amd(b, input_mask),
.repacked_invocation_index = nir_build_mbcnt_amd(b, input_mask, nir_imm_int(b, 0)),
};
return r;
}
@@ -182,10 +182,9 @@ repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool,
unreachable("Unimplemented NGG wave count");
}
nir_ssa_def *wave_repacked_index = nir_build_mbcnt_amd(b, input_mask);
nir_ssa_def *wg_repacked_index_base = nir_build_read_invocation(b, sum, wave_id);
nir_ssa_def *wg_num_repacked_invocations = nir_build_read_invocation(b, sum, num_waves);
nir_ssa_def *wg_repacked_index = nir_iadd_nuw(b, wg_repacked_index_base, wave_repacked_index);
nir_ssa_def *wg_repacked_index = nir_build_mbcnt_amd(b, input_mask, wg_repacked_index_base);
wg_repack_result r = {
.num_repacked_invocations = wg_num_repacked_invocations,

View File

@@ -8380,10 +8380,11 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
}
case nir_intrinsic_mbcnt_amd: {
Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
Temp add_src = get_ssa_temp(ctx, instr->src[1].ssa);
Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
/* Fit 64-bit mask for wave32 */
src = emit_extract_vector(ctx, src, 0, RegClass(src.type(), bld.lm.size()));
Temp wqm_tmp = emit_mbcnt(ctx, bld.tmp(v1), Operand(src));
Temp wqm_tmp = emit_mbcnt(ctx, bld.tmp(v1), Operand(src), Operand(add_src));
emit_wqm(bld, wqm_tmp, dst);
break;
}

View File

@@ -3452,7 +3452,7 @@ LLVMValueRef ac_build_writelane(struct ac_llvm_context *ctx, LLVMValueRef src, L
AC_FUNC_ATTR_READNONE | AC_FUNC_ATTR_CONVERGENT);
}
LLVMValueRef ac_build_mbcnt(struct ac_llvm_context *ctx, LLVMValueRef mask)
LLVMValueRef ac_build_mbcnt_add(struct ac_llvm_context *ctx, LLVMValueRef mask, LLVMValueRef add_src)
{
if (ctx->wave_size == 32) {
LLVMValueRef val = ac_build_intrinsic(ctx, "llvm.amdgcn.mbcnt.lo", ctx->i32,
@@ -3465,13 +3465,18 @@ LLVMValueRef ac_build_mbcnt(struct ac_llvm_context *ctx, LLVMValueRef mask)
LLVMValueRef mask_hi = LLVMBuildExtractElement(ctx->builder, mask_vec, ctx->i32_1, "");
LLVMValueRef val =
ac_build_intrinsic(ctx, "llvm.amdgcn.mbcnt.lo", ctx->i32,
(LLVMValueRef[]){mask_lo, ctx->i32_0}, 2, AC_FUNC_ATTR_READNONE);
(LLVMValueRef[]){mask_lo, add_src}, 2, AC_FUNC_ATTR_READNONE);
val = ac_build_intrinsic(ctx, "llvm.amdgcn.mbcnt.hi", ctx->i32, (LLVMValueRef[]){mask_hi, val},
2, AC_FUNC_ATTR_READNONE);
ac_set_range_metadata(ctx, val, 0, ctx->wave_size);
return val;
}
LLVMValueRef ac_build_mbcnt(struct ac_llvm_context *ctx, LLVMValueRef mask)
{
return ac_build_mbcnt_add(ctx, mask, ctx->i32_0);
}
enum dpp_ctrl
{
_dpp_quad_perm = 0x000,

View File

@@ -504,6 +504,7 @@ LLVMValueRef ac_build_readlane(struct ac_llvm_context *ctx, LLVMValueRef src, LL
LLVMValueRef ac_build_writelane(struct ac_llvm_context *ctx, LLVMValueRef src, LLVMValueRef value,
LLVMValueRef lane);
LLVMValueRef ac_build_mbcnt_add(struct ac_llvm_context *ctx, LLVMValueRef mask, LLVMValueRef add_src);
LLVMValueRef ac_build_mbcnt(struct ac_llvm_context *ctx, LLVMValueRef mask);
LLVMValueRef ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op);

View File

@@ -3924,7 +3924,7 @@ static void visit_intrinsic(struct ac_nir_context *ctx, nir_intrinsic_instr *ins
get_src(ctx, instr->src[1]), get_src(ctx, instr->src[2]));
break;
case nir_intrinsic_mbcnt_amd:
result = ac_build_mbcnt(&ctx->ac, get_src(ctx, instr->src[0]));
result = ac_build_mbcnt_add(&ctx->ac, get_src(ctx, instr->src[0]), get_src(ctx, instr->src[1]));
break;
case nir_intrinsic_load_scratch: {
LLVMValueRef offset = get_src(ctx, instr->src[0]);

View File

@@ -428,7 +428,8 @@ intrinsic("masked_swizzle_amd", src_comp=[0], dest_comp=0, bit_sizes=src0,
indices=[SWIZZLE_MASK], flags=[CAN_ELIMINATE])
intrinsic("write_invocation_amd", src_comp=[0, 0, 1], dest_comp=0, bit_sizes=src0,
flags=[CAN_ELIMINATE])
intrinsic("mbcnt_amd", src_comp=[1], dest_comp=1, bit_sizes=[32], flags=[CAN_ELIMINATE])
# src = [ mask, addition ]
intrinsic("mbcnt_amd", src_comp=[1, 1], dest_comp=1, bit_sizes=[32], flags=[CAN_ELIMINATE])
# Compiled to v_perm_b32. src = [ in_bytes_hi, in_bytes_lo, selector ]
intrinsic("byte_permute_amd", src_comp=[1, 1, 1], dest_comp=1, bit_sizes=[32], flags=[CAN_ELIMINATE, CAN_REORDER])
# Compiled to v_permlane16_b32. src = [ value, lanesel_lo, lanesel_hi ]

View File

@@ -1329,9 +1329,18 @@ nir_unsigned_upper_bound(nir_shader *shader, struct hash_table *range_ht,
break;
case nir_intrinsic_load_subgroup_invocation:
case nir_intrinsic_first_invocation:
case nir_intrinsic_mbcnt_amd:
res = config->max_subgroup_size - 1;
break;
case nir_intrinsic_mbcnt_amd: {
uint32_t src0 = config->max_subgroup_size - 1;
uint32_t src1 = nir_unsigned_upper_bound(shader, range_ht, (nir_ssa_scalar){intrin->src[1].ssa, 0}, config);
if (src0 + src1 < src0)
res = max; /* overflow */
else
res = src0 + src1;
break;
}
case nir_intrinsic_load_subgroup_size:
res = config->max_subgroup_size;
break;

View File

@@ -101,6 +101,11 @@ vtn_handle_amd_shader_ballot_instruction(struct vtn_builder *b, SpvOp ext_opcode
val->constant->values[1].u32 << 5 |
val->constant->values[2].u32 << 10;
nir_intrinsic_set_swizzle_mask(intrin, mask);
} else if (intrin->intrinsic == nir_intrinsic_mbcnt_amd) {
/* The v_mbcnt instruction has an additional source that is added to the result.
* This is exposed by the NIR intrinsic but not by SPIR-V, so we add zero here.
*/
intrin->src[1] = nir_src_for_ssa(nir_imm_int(&b->nb, 0));
}
nir_builder_instr_insert(&b->nb, &intrin->instr);