amd: lower load_tess_rel_patch_id/primitive_id/tess_coord and overwrite_tes_arguments_amd in NIR

The overwrite_tes_arguments_amd intrinsic complicates the lowering a little,
which is why these intrinsics are lowered together.

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32782>
This commit is contained in:
Marek Olšák
2024-12-25 12:44:59 -05:00
committed by Marge Bot
parent 61bfb4fa06
commit ceb6f8fc32
7 changed files with 62 additions and 148 deletions

View File

@@ -106,6 +106,10 @@ typedef struct {
nir_def *vertex_id;
nir_def *instance_id;
nir_def *vs_rel_patch_id;
nir_def *tes_u;
nir_def *tes_v;
nir_def *tes_patch_id;
nir_def *tes_rel_patch_id;
} lower_intrinsics_to_args_state;
static nir_def *
@@ -368,6 +372,13 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
s->instance_id = intrin->src[1].ssa;
nir_instr_remove(instr);
return true;
case nir_intrinsic_overwrite_tes_arguments_amd:
s->tes_u = intrin->src[0].ssa;
s->tes_v = intrin->src[1].ssa;
s->tes_patch_id = intrin->src[2].ssa;
s->tes_rel_patch_id = intrin->src[3].ssa;
nir_instr_remove(instr);
return true;
case nir_intrinsic_load_vertex_id_zero_base:
if (!s->vertex_id)
s->vertex_id = preload_arg(s, b->impl, s->args->vertex_id, s->args->tcs_patch_id);
@@ -378,6 +389,57 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state)
s->instance_id = preload_arg(s, b->impl, s->args->instance_id, s->args->vertex_id);
replacement = s->instance_id;
break;
case nir_intrinsic_load_tess_rel_patch_id_amd:
if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
replacement = ac_nir_unpack_arg(b, s->args, s->args->tcs_rel_ids, 0, 8);
} else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
if (s->tes_rel_patch_id) {
replacement = s->tes_rel_patch_id;
} else {
replacement = ac_nir_load_arg(b, s->args, s->args->tes_rel_patch_id);
if (b->shader->info.tess.tcs_vertices_out) {
/* Setting an upper bound like this will actually make it possible
* to optimize some multiplications (in address calculations) so that
* constant additions can be added to the const offset in memory load instructions.
*/
nir_intrinsic_set_arg_upper_bound_u32_amd(nir_instr_as_intrinsic(replacement->parent_instr),
2048 / b->shader->info.tess.tcs_vertices_out);
}
}
} else {
unreachable("invalid stage");
}
break;
case nir_intrinsic_load_primitive_id:
if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
replacement = ac_nir_load_arg(b, s->args, s->args->gs_prim_id);
} else if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
replacement = ac_nir_load_arg(b, s->args, s->args->tcs_patch_id);
} else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
replacement = s->tes_patch_id ? s->tes_patch_id :
ac_nir_load_arg(b, s->args, s->args->tes_patch_id);
} else if (b->shader->info.stage == MESA_SHADER_VERTEX) {
if (s->hw_stage == AC_HW_VERTEX_SHADER)
replacement = ac_nir_load_arg(b, s->args, s->args->vs_prim_id); /* legacy */
else
replacement = ac_nir_load_arg(b, s->args, s->args->gs_prim_id); /* NGG */
} else {
unreachable("invalid stage");
}
break;
case nir_intrinsic_load_tess_coord: {
nir_def *coord[3] = {
s->tes_u ? s->tes_u : ac_nir_load_arg(b, s->args, s->args->tes_u),
s->tes_v ? s->tes_v : ac_nir_load_arg(b, s->args, s->args->tes_v),
nir_imm_float(b, 0),
};
/* For triangles, the vector should be (u, v, 1-u-v). */
if (b->shader->info.tess._primitive_mode == TESS_PRIMITIVE_TRIANGLES)
coord[2] = nir_fsub(b, nir_imm_float(b, 1), nir_fadd(b, coord[0], coord[1]));
replacement = nir_vec(b, coord, 3);
break;
}
case nir_intrinsic_load_local_invocation_index:
/* GFX11 HS has subgroup_id, so use it instead of vs_rel_patch_id. */
if (s->gfx_level < GFX11 &&

View File

@@ -5688,28 +5688,6 @@ visit_load_per_vertex_input(isel_context* ctx, nir_intrinsic_instr* instr)
}
}
/* Emit ACO IR for nir_intrinsic_load_tess_coord.
 *
 * Builds a 3-component vector (u, v, w) in the destination temp of `instr`
 * from the tes_u and tes_v shader arguments. Only valid in a tessellation
 * evaluation shader (asserted below).
 *
 * ctx:   instruction-selection context (program, current block, args).
 * instr: the load_tess_coord intrinsic whose def receives the result.
 */
void
visit_load_tess_coord(isel_context* ctx, nir_intrinsic_instr* instr)
{
assert(ctx->shader->info.stage == MESA_SHADER_TESS_EVAL);
Builder bld(ctx->program, ctx->block);
Temp dst = get_ssa_temp(ctx, &instr->def);
/* u and v are taken directly from the shader arguments. */
Operand tes_u(get_arg(ctx, ctx->args->tes_u));
Operand tes_v(get_arg(ctx, ctx->args->tes_v));
/* The third component stays zero unless the primitive mode is triangles. */
Operand tes_w = Operand::zero();
if (ctx->shader->info.tess._primitive_mode == TESS_PRIMITIVE_TRIANGLES) {
/* For triangles the third component is derived from u and v:
 * tmp = u + v, then tmp = 1.0f - tmp (0x3f800000 is the bit pattern of 1.0f).
 */
Temp tmp = bld.vop2(aco_opcode::v_add_f32, bld.def(v1), tes_u, tes_v);
tmp = bld.vop2(aco_opcode::v_sub_f32, bld.def(v1), Operand::c32(0x3f800000u /* 1.0f */), tmp);
tes_w = Operand(tmp);
}
/* Pack (u, v, w) into dst, then register the vector as split into 3 components. */
Temp tess_coord = bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tes_u, tes_v, tes_w);
emit_split_vector(ctx, tess_coord, 3);
}
ac_hw_cache_flags
get_cache_flags(isel_context* ctx, unsigned access)
{
@@ -7970,7 +7948,6 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
{
Builder bld(ctx->program, ctx->block);
switch (instr->intrinsic) {
case nir_intrinsic_load_tess_coord: visit_load_tess_coord(ctx, instr); break;
case nir_intrinsic_load_interpolated_input: visit_load_interpolated_input(ctx, instr); break;
case nir_intrinsic_store_output: visit_store_output(ctx, instr); break;
case nir_intrinsic_load_input:
@@ -8768,34 +8745,6 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
emit_split_vector(ctx, dst, 2);
break;
}
case nir_intrinsic_load_primitive_id: {
Temp dst = get_ssa_temp(ctx, &instr->def);
switch (ctx->shader->info.stage) {
case MESA_SHADER_GEOMETRY:
bld.copy(Definition(dst), get_arg(ctx, ctx->args->gs_prim_id));
break;
case MESA_SHADER_TESS_CTRL:
bld.copy(Definition(dst), get_arg(ctx, ctx->args->tcs_patch_id));
break;
case MESA_SHADER_TESS_EVAL:
bld.copy(Definition(dst), get_arg(ctx, ctx->args->tes_patch_id));
break;
default:
if (ctx->stage.hw == AC_HW_NEXT_GEN_GEOMETRY_SHADER && !ctx->stage.has(SWStage::GS)) {
/* In case of NGG, the GS threads always have the primitive ID
* even if there is no SW GS. */
bld.copy(Definition(dst), get_arg(ctx, ctx->args->gs_prim_id));
break;
} else if (ctx->shader->info.stage == MESA_SHADER_VERTEX) {
bld.copy(Definition(dst), get_arg(ctx, ctx->args->vs_prim_id));
break;
}
unreachable("Unimplemented shader stage for nir_intrinsic_load_primitive_id");
}
break;
}
case nir_intrinsic_sendmsg_amd: {
unsigned imm = nir_intrinsic_base(instr);
Temp m0_content = bld.as_uniform(get_ssa_temp(ctx, instr->src[0].ssa));
@@ -8830,13 +8779,6 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
bld.def(s1, scc), Operand::c32(nir_intrinsic_call_idx(instr)));
break;
}
case nir_intrinsic_overwrite_tes_arguments_amd: {
ctx->arg_temps[ctx->args->tes_u.arg_index] = get_ssa_temp(ctx, instr->src[0].ssa);
ctx->arg_temps[ctx->args->tes_v.arg_index] = get_ssa_temp(ctx, instr->src[1].ssa);
ctx->arg_temps[ctx->args->tes_rel_patch_id.arg_index] = get_ssa_temp(ctx, instr->src[3].ssa);
ctx->arg_temps[ctx->args->tes_patch_id.arg_index] = get_ssa_temp(ctx, instr->src[2].ssa);
break;
}
case nir_intrinsic_load_scalar_arg_amd:
case nir_intrinsic_load_vector_arg_amd: {
assert(nir_intrinsic_base(instr) < ctx->args->arg_count);

View File

@@ -550,7 +550,6 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_intrinsic_load_per_vertex_input:
case nir_intrinsic_load_per_vertex_output:
case nir_intrinsic_load_interpolated_input:
case nir_intrinsic_load_tess_coord:
case nir_intrinsic_write_invocation_amd:
case nir_intrinsic_mbcnt_amd:
case nir_intrinsic_lane_permute_16_amd:
@@ -565,7 +564,6 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_intrinsic_shared_atomic:
case nir_intrinsic_shared_atomic_swap:
case nir_intrinsic_load_scratch:
case nir_intrinsic_load_primitive_id:
case nir_intrinsic_load_typed_buffer_amd:
case nir_intrinsic_load_buffer_amd:
case nir_intrinsic_load_initial_edgeflags_amd:

View File

@@ -2791,40 +2791,11 @@ static bool visit_intrinsic(struct ac_nir_context *ctx, nir_intrinsic_instr *ins
result = ac_build_gather_values(&ctx->ac, values, 3);
break;
}
case nir_intrinsic_load_tess_rel_patch_id_amd:
switch (ctx->stage) {
case MESA_SHADER_TESS_CTRL:
result = ac_unpack_param(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->tcs_rel_ids), 0, 8);
break;
case MESA_SHADER_TESS_EVAL:
result = ctx->abi->tes_rel_patch_id_replaced ? ctx->abi->tes_rel_patch_id_replaced :
ac_get_arg(&ctx->ac, ctx->args->tes_rel_patch_id);
break;
default:
unreachable("invalid stage");
}
break;
case nir_intrinsic_load_ring_attr_amd:
case nir_intrinsic_load_lds_ngg_scratch_base_amd:
case nir_intrinsic_load_lds_ngg_gs_out_vertex_base_amd:
result = ctx->abi->intrinsic_load(ctx->abi, instr);
break;
case nir_intrinsic_load_primitive_id:
if (ctx->stage == MESA_SHADER_GEOMETRY) {
result = ac_get_arg(&ctx->ac, ctx->args->gs_prim_id);
} else if (ctx->stage == MESA_SHADER_TESS_CTRL) {
result = ac_get_arg(&ctx->ac, ctx->args->tcs_patch_id);
} else if (ctx->stage == MESA_SHADER_TESS_EVAL) {
result = ctx->abi->tes_patch_id_replaced ?
ctx->abi->tes_patch_id_replaced : ac_get_arg(&ctx->ac, ctx->args->tes_patch_id);
} else if (ctx->stage == MESA_SHADER_VERTEX) {
if (ctx->args->vs_prim_id.used)
result = ac_get_arg(&ctx->ac, ctx->args->vs_prim_id); /* legacy */
else
result = ac_get_arg(&ctx->ac, ctx->args->gs_prim_id); /* NGG */
} else
fprintf(stderr, "Unknown primitive id intrinsic: %d", ctx->stage);
break;
case nir_intrinsic_load_helper_invocation:
case nir_intrinsic_is_helper_invocation:
result = ac_build_load_helper_invocation(&ctx->ac);
@@ -2966,21 +2937,6 @@ static bool visit_intrinsic(struct ac_nir_context *ctx, nir_intrinsic_instr *ins
ac_build_sendmsg(&ctx->ac, imm, m0_content);
break;
}
case nir_intrinsic_load_tess_coord: {
LLVMValueRef coord[] = {
ctx->abi->tes_u_replaced ? ctx->abi->tes_u_replaced : ac_get_arg(&ctx->ac, ctx->args->tes_u),
ctx->abi->tes_v_replaced ? ctx->abi->tes_v_replaced : ac_get_arg(&ctx->ac, ctx->args->tes_v),
ctx->ac.f32_0,
};
/* For triangles, the vector should be (u, v, 1-u-v). */
if (ctx->info->tess._primitive_mode == TESS_PRIMITIVE_TRIANGLES) {
coord[2] = LLVMBuildFSub(ctx->ac.builder, ctx->ac.f32_1,
LLVMBuildFAdd(ctx->ac.builder, coord[0], coord[1], ""), "");
}
result = ac_build_gather_values(&ctx->ac, coord, 3);
break;
}
case nir_intrinsic_vote_all: {
result = ac_build_vote_all(&ctx->ac, get_src(ctx, instr->src[0]));
if (ctx->info->stage == MESA_SHADER_FRAGMENT)
@@ -3239,12 +3195,6 @@ static bool visit_intrinsic(struct ac_nir_context *ctx, nir_intrinsic_instr *ins
result = LLVMBuildICmp(ctx->ac.builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, "");
break;
}
case nir_intrinsic_overwrite_tes_arguments_amd:
ctx->abi->tes_u_replaced = ac_to_float(&ctx->ac, get_src(ctx, instr->src[0]));
ctx->abi->tes_v_replaced = ac_to_float(&ctx->ac, get_src(ctx, instr->src[1]));
ctx->abi->tes_rel_patch_id_replaced = get_src(ctx, instr->src[3]);
ctx->abi->tes_patch_id_replaced = get_src(ctx, instr->src[2]);
break;
case nir_intrinsic_gds_atomic_add_amd: {
LLVMValueRef store_val = get_src(ctx, instr->src[0]);
LLVMValueRef addr = get_src(ctx, instr->src[1]);

View File

@@ -25,12 +25,6 @@ struct ac_shader_abi {
LLVMValueRef outputs[AC_LLVM_MAX_OUTPUTS * 4];
bool is_16bit[AC_LLVM_MAX_OUTPUTS * 4];
/* replaced registers when culling enabled */
LLVMValueRef tes_u_replaced;
LLVMValueRef tes_v_replaced;
LLVMValueRef tes_rel_patch_id_replaced;
LLVMValueRef tes_patch_id_replaced;
LLVMValueRef (*load_tess_varyings)(struct ac_shader_abi *abi, LLVMTypeRef type,
unsigned driver_location, unsigned component,
unsigned num_components);

View File

@@ -138,26 +138,6 @@ lower_abi_instr(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
break;
}
case nir_intrinsic_load_tess_rel_patch_id_amd:
if (stage == MESA_SHADER_TESS_CTRL) {
replacement = nir_extract_u8(b, ac_nir_load_arg(b, &s->args->ac, s->args->ac.tcs_rel_ids), nir_imm_int(b, 0));
} else if (stage == MESA_SHADER_TESS_EVAL) {
/* Setting an upper bound like this will actually make it possible
* to optimize some multiplications (in address calculations) so that
* constant additions can be added to the const offset in memory load instructions.
*/
nir_def *arg = ac_nir_load_arg(b, &s->args->ac, s->args->ac.tes_rel_patch_id);
if (s->info->tes.tcs_vertices_out) {
nir_intrinsic_instr *load_arg = nir_instr_as_intrinsic(arg->parent_instr);
nir_intrinsic_set_arg_upper_bound_u32_amd(load_arg, 2048 / MAX2(s->info->tes.tcs_vertices_out, 1));
}
replacement = arg;
} else {
unreachable("invalid tessellation shader stage");
}
break;
case nir_intrinsic_load_patch_vertices_in:
if (stage == MESA_SHADER_TESS_CTRL) {
if (s->gfx_state->ts.patch_control_points) {

View File

@@ -715,18 +715,6 @@ static bool lower_intrinsic(nir_builder *b, nir_instr *instr, struct lower_abi_s
assert(s->esgs_ring);
replacement = s->esgs_ring;
break;
case nir_intrinsic_load_tess_rel_patch_id_amd:
/* LLVM need to replace patch id arg, so have to be done in LLVM backend. */
if (!b->shader->info.use_aco_amd)
return false;
if (stage == MESA_SHADER_TESS_CTRL) {
replacement = ac_nir_unpack_arg(b, &args->ac, args->ac.tcs_rel_ids, 0, 8);
} else {
assert(stage == MESA_SHADER_TESS_EVAL);
replacement = ac_nir_load_arg(b, &args->ac, args->ac.tes_rel_patch_id);
}
break;
case nir_intrinsic_load_ring_tess_offchip_amd:
assert(s->tess_offchip_ring);
replacement = s->tess_offchip_ring;