nir: add nir_intrinsic_load_per_primitive_input, split from io_semantics flag

Instead of having 1 bit in nir_io_semantics indicating a per-primitive
FS input, add a dedicated intrinsic for it.

Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29895>
This commit is contained in:
Marek Olšák
2024-07-06 04:24:31 -04:00
committed by Marge Bot
parent ecfefe823e
commit b2d32ae246
29 changed files with 72 additions and 45 deletions

View File

@@ -234,7 +234,8 @@ can_move_coord(nir_scalar scalar, coord_info *info)
return false;
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(scalar.def->parent_instr);
if (intrin->intrinsic == nir_intrinsic_load_input) {
if (intrin->intrinsic == nir_intrinsic_load_input ||
intrin->intrinsic == nir_intrinsic_load_per_primitive_input) {
info->bary = NULL;
info->load = intrin;
return true;

View File

@@ -8410,6 +8410,7 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
case nir_intrinsic_load_interpolated_input: visit_load_interpolated_input(ctx, instr); break;
case nir_intrinsic_store_output: visit_store_output(ctx, instr); break;
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_input_vertex:
if (ctx->program->stage == fragment_fs)
visit_load_fs_input(ctx, instr);

View File

@@ -458,6 +458,7 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_intrinsic_unit_test_uniform_amd: type = RegType::sgpr; break;
case nir_intrinsic_load_sample_id:
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_output:
case nir_intrinsic_load_input_vertex:
case nir_intrinsic_load_per_vertex_input:

View File

@@ -3072,6 +3072,7 @@ static bool visit_intrinsic(struct ac_nir_context *ctx, nir_intrinsic_instr *ins
result = visit_get_ssbo_size(ctx, instr);
break;
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_input_vertex:
case nir_intrinsic_load_per_vertex_input:
result = visit_load(ctx, instr);

View File

@@ -294,6 +294,7 @@ gather_intrinsic_info(const nir_shader *nir, const nir_intrinsic_instr *instr, s
break;
}
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_interpolated_input:
case nir_intrinsic_load_input_vertex:
gather_intrinsic_load_input_info(nir, instr, info, gfx_state, stage_key);

View File

@@ -3008,6 +3008,7 @@ nir_intrinsic_instr_dest_type(const nir_intrinsic_instr *intrin)
}
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_uniform:
return nir_intrinsic_dest_type(intrin);

View File

@@ -2017,9 +2017,7 @@ typedef struct nir_io_semantics {
unsigned no_sysval_output : 1; /* whether this system value output has no
effect due to current pipeline states */
unsigned interp_explicit_strict : 1; /* preserve original vertex order */
unsigned per_primitive : 1; /* Per-primitive FS input (when FS is used with a mesh shader).
Note that per-primitive MS outputs are implied by
using a dedicated intrinsic, store_per_primitive_output. */
unsigned _pad : 1;
} nir_io_semantics;
/* Transform feedback info for 2 outputs. nir_intrinsic_store_output contains

View File

@@ -289,6 +289,7 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
is_divergent = !(options & nir_divergence_single_frag_shading_rate_per_subgroup);
break;
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
is_divergent = instr->src[0].ssa->divergent;
if (stage == MESA_SHADER_FRAGMENT) {

View File

@@ -539,6 +539,7 @@ gather_intrinsic_info(nir_intrinsic_instr *instr, nir_shader *shader,
case nir_intrinsic_load_per_vertex_input:
case nir_intrinsic_load_input_vertex:
case nir_intrinsic_load_interpolated_input:
case nir_intrinsic_load_per_primitive_input:
if (shader->info.stage == MESA_SHADER_TESS_EVAL &&
instr->intrinsic == nir_intrinsic_load_input &&
!is_patch_special) {
@@ -549,7 +550,7 @@ gather_intrinsic_info(nir_intrinsic_instr *instr, nir_shader *shader,
shader->info.inputs_read |= slot_mask;
if (nir_intrinsic_io_semantics(instr).high_dvec2)
shader->info.dual_slot_inputs |= slot_mask;
if (nir_intrinsic_io_semantics(instr).per_primitive)
if (instr->intrinsic == nir_intrinsic_load_per_primitive_input)
shader->info.per_primitive_inputs |= slot_mask;
shader->info.inputs_read_16bit |= slot_mask_16bit;
if (!nir_src_is_const(*nir_get_io_offset_src(instr))) {

View File

@@ -1137,6 +1137,8 @@ load("input_vertex", [1, 1], [BASE, COMPONENT, DEST_TYPE, IO_SEMANTICS], [CAN_EL
load("per_vertex_input", [1, 1], [BASE, RANGE, COMPONENT, DEST_TYPE, IO_SEMANTICS], [CAN_ELIMINATE, CAN_REORDER])
# src[] = { barycoord, offset }.
load("interpolated_input", [2, 1], [BASE, COMPONENT, DEST_TYPE, IO_SEMANTICS], [CAN_ELIMINATE, CAN_REORDER])
# src[] = { offset }.
load("per_primitive_input", [1], [BASE, COMPONENT, DEST_TYPE, IO_SEMANTICS], [CAN_ELIMINATE, CAN_REORDER])
# src[] = { buffer_index, offset }.
load("ssbo", [-1, 1], [ACCESS, ALIGN_MUL, ALIGN_OFFSET], [CAN_ELIMINATE])

View File

@@ -320,7 +320,12 @@ emit_load(struct lower_io_state *state,
op = nir_intrinsic_load_interpolated_input;
}
} else {
op = array_index ? nir_intrinsic_load_per_vertex_input : nir_intrinsic_load_input;
if (var->data.per_primitive)
op = nir_intrinsic_load_per_primitive_input;
else if (array_index)
op = nir_intrinsic_load_per_vertex_input;
else
op = nir_intrinsic_load_input;
}
break;
case nir_var_shader_out:
@@ -363,7 +368,6 @@ emit_load(struct lower_io_state *state,
semantics.fb_fetch_output = var->data.fb_fetch_output;
semantics.medium_precision = is_medium_precision(b->shader, var);
semantics.high_dvec2 = high_dvec2;
semantics.per_primitive = var->data.per_primitive;
/* "per_vertex" is misnamed. It means "explicit interpolation with
* the original vertex order", which is a stricter version of
* INTERP_MODE_EXPLICIT.
@@ -2741,6 +2745,7 @@ nir_get_io_offset_src_number(const nir_intrinsic_instr *instr)
{
switch (instr->intrinsic) {
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_output:
case nir_intrinsic_load_shared:
case nir_intrinsic_load_task_payload:
@@ -2941,6 +2946,7 @@ static bool
is_input(nir_intrinsic_instr *intrin)
{
return intrin->intrinsic == nir_intrinsic_load_input ||
intrin->intrinsic == nir_intrinsic_load_per_primitive_input ||
intrin->intrinsic == nir_intrinsic_load_input_vertex ||
intrin->intrinsic == nir_intrinsic_load_per_vertex_input ||
intrin->intrinsic == nir_intrinsic_load_interpolated_input ||

View File

@@ -278,6 +278,7 @@ nir_lower_io_to_scalar_instr(nir_builder *b, nir_instr *instr, void *data)
return false;
if ((intr->intrinsic == nir_intrinsic_load_input ||
intr->intrinsic == nir_intrinsic_load_per_primitive_input ||
intr->intrinsic == nir_intrinsic_load_per_vertex_input ||
intr->intrinsic == nir_intrinsic_load_interpolated_input ||
intr->intrinsic == nir_intrinsic_load_input_vertex) &&

View File

@@ -39,6 +39,7 @@ get_io_intrinsic(nir_instr *instr, nir_variable_mode modes,
switch (intr->intrinsic) {
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_input_vertex:
case nir_intrinsic_load_interpolated_input:
case nir_intrinsic_load_per_vertex_input:
@@ -90,7 +91,7 @@ nir_recompute_io_bases(nir_shader *nir, nir_variable_mode modes)
if (mode == nir_var_shader_in) {
for (unsigned i = 0; i < num_slots; i++) {
if (sem.per_primitive)
if (intr->intrinsic == nir_intrinsic_load_per_primitive_input)
BITSET_SET(per_prim_inputs, sem.location + i);
else
BITSET_SET(inputs, sem.location + i);
@@ -123,7 +124,7 @@ nir_recompute_io_bases(nir_shader *nir, nir_variable_mode modes)
num_slots = (num_slots + sem.high_16bits + 1) / 2;
if (mode == nir_var_shader_in) {
if (sem.per_primitive) {
if (intr->intrinsic == nir_intrinsic_load_per_primitive_input) {
nir_intrinsic_set_base(intr,
num_normal_inputs +
BITSET_PREFIX_SUM(per_prim_inputs, sem.location));

View File

@@ -101,6 +101,7 @@ is_phi_src_scalarizable(nir_phi_src *src,
case nir_intrinsic_load_global:
case nir_intrinsic_load_global_constant:
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
return true;
default:
break;

View File

@@ -147,6 +147,7 @@ lower_system_value_instr(nir_builder *b, nir_instr *instr, void *_state)
}
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
if (b->shader->options->lower_layer_fs_input_to_sysval &&
b->shader->info.stage == MESA_SHADER_FRAGMENT &&
nir_intrinsic_io_semantics(intrin).location == VARYING_SLOT_LAYER)

View File

@@ -211,6 +211,7 @@ is_src_scalarizable(nir_src *src)
case nir_intrinsic_load_global:
case nir_intrinsic_load_global_constant:
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
return true;
default:
break;

View File

@@ -312,6 +312,7 @@ opt_shrink_vectors_intrinsic(nir_builder *b, nir_intrinsic_instr *instr,
case nir_intrinsic_load_uniform:
case nir_intrinsic_load_ubo:
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_input_vertex:
case nir_intrinsic_load_per_vertex_input:
case nir_intrinsic_load_interpolated_input:

View File

@@ -105,6 +105,7 @@ nir_can_move_instr(nir_instr *instr, nir_move_options options)
case nir_intrinsic_load_ssbo:
return (options & nir_move_load_ssbo) && nir_intrinsic_can_reorder(intrin);
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_interpolated_input:
case nir_intrinsic_load_per_vertex_input:
case nir_intrinsic_load_frag_coord:

View File

@@ -1090,6 +1090,7 @@ gather_inputs(struct nir_builder *builder, nir_intrinsic_instr *intr, void *cb_d
if (intr->intrinsic != nir_intrinsic_load_input &&
intr->intrinsic != nir_intrinsic_load_per_vertex_input &&
intr->intrinsic != nir_intrinsic_load_per_primitive_input &&
intr->intrinsic != nir_intrinsic_load_interpolated_input &&
intr->intrinsic != nir_intrinsic_load_input_vertex)
return false;
@@ -1127,10 +1128,10 @@ gather_inputs(struct nir_builder *builder, nir_intrinsic_instr *intr, void *cb_d
if (linkage->consumer_stage == MESA_SHADER_FRAGMENT) {
switch (intr->intrinsic) {
case nir_intrinsic_load_input:
if (sem.per_primitive)
fs_vec4_type = FS_VEC4_TYPE_PER_PRIMITIVE;
else
fs_vec4_type = FS_VEC4_TYPE_FLAT;
fs_vec4_type = FS_VEC4_TYPE_FLAT;
break;
case nir_intrinsic_load_per_primitive_input:
fs_vec4_type = FS_VEC4_TYPE_PER_PRIMITIVE;
break;
case nir_intrinsic_load_input_vertex:
if (sem.interp_explicit_strict)
@@ -1176,19 +1177,20 @@ gather_inputs(struct nir_builder *builder, nir_intrinsic_instr *intr, void *cb_d
if (linkage->consumer_stage == MESA_SHADER_FRAGMENT) {
switch (intr->intrinsic) {
case nir_intrinsic_load_input:
if (intr->def.bit_size == 32) {
if (sem.per_primitive)
BITSET_SET(linkage->per_primitive32_mask, slot);
else
BITSET_SET(linkage->flat32_mask, slot);
} else if (intr->def.bit_size == 16) {
if (sem.per_primitive)
BITSET_SET(linkage->per_primitive16_mask, slot);
else
BITSET_SET(linkage->flat16_mask, slot);
} else {
if (intr->def.bit_size == 32)
BITSET_SET(linkage->flat32_mask, slot);
else if (intr->def.bit_size == 16)
BITSET_SET(linkage->flat16_mask, slot);
else
unreachable("invalid load_input type");
break;
case nir_intrinsic_load_per_primitive_input:
if (intr->def.bit_size == 32)
BITSET_SET(linkage->per_primitive32_mask, slot);
else if (intr->def.bit_size == 16)
BITSET_SET(linkage->per_primitive16_mask, slot);
else
unreachable("invalid load_input type");
}
break;
case nir_intrinsic_load_input_vertex:
if (sem.interp_explicit_strict) {
@@ -2009,6 +2011,7 @@ clone_ssa(struct linkage_info *linkage, nir_builder *b, nir_def *ssa)
}
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_interpolated_input: {
/* We are cloning load_input in the producer for backward
* inter-shader code motion. Replace the input load with the stored
@@ -2263,11 +2266,11 @@ get_input_qualifier(struct linkage_info *linkage, unsigned i)
nir_intrinsic_instr *load =
list_first_entry(&slot->consumer.loads, struct list_node, head)->instr;
if (load->intrinsic == nir_intrinsic_load_input) {
if (nir_intrinsic_io_semantics(load).per_primitive)
return QUAL_PER_PRIMITIVE;
if (load->intrinsic == nir_intrinsic_load_input)
return is_color ? QUAL_COLOR_FLAT : QUAL_VAR_FLAT;
}
if (load->intrinsic == nir_intrinsic_load_per_primitive_input)
return QUAL_PER_PRIMITIVE;
if (load->intrinsic == nir_intrinsic_load_input_vertex) {
return nir_intrinsic_io_semantics(load).interp_explicit_strict ?
@@ -3479,10 +3482,9 @@ backward_inter_shader_code_motion(struct linkage_info *linkage,
load->instr.pass_flags |= FLAG_INTERP_FLAT;
}
break;
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_input_vertex:
/* Inter-shader code motion is unimplemented for explicit
* interpolation.
*/
/* Inter-shader code motion is unimplemented for these. */
continue;
default:
unreachable("unexpected load intrinsic");
@@ -3704,12 +3706,11 @@ relocate_slot(struct linkage_info *linkage, struct scalar_slot *slot,
if (fs_vec4_type == FS_VEC4_TYPE_PER_PRIMITIVE) {
assert(intr->intrinsic == nir_intrinsic_store_per_primitive_output ||
intr->intrinsic == nir_intrinsic_load_per_primitive_output ||
intr->intrinsic == nir_intrinsic_load_input);
assert(intr->intrinsic != nir_intrinsic_load_input || sem.per_primitive);
intr->intrinsic == nir_intrinsic_load_per_primitive_input);
} else {
assert(!sem.per_primitive);
assert(intr->intrinsic != nir_intrinsic_store_per_primitive_output &&
intr->intrinsic != nir_intrinsic_load_per_primitive_output);
intr->intrinsic != nir_intrinsic_load_per_primitive_output &&
intr->intrinsic != nir_intrinsic_load_per_primitive_input);
}
/* This path is used when promoting convergent interpolated

View File

@@ -69,9 +69,6 @@ compare_is_not_vectorizable(nir_intrinsic_instr *a, nir_intrinsic_instr *b)
if (sem0.interp_explicit_strict != sem1.interp_explicit_strict)
return sem0.interp_explicit_strict > sem1.interp_explicit_strict ? 1 : -1;
if (sem0.per_primitive != sem1.per_primitive)
return sem0.per_primitive > sem1.per_primitive ? 1 : -1;
/* Only load_interpolated_input can't merge low and high halves of 16-bit
* loads/stores.
*/
@@ -488,6 +485,7 @@ nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
switch (intr->intrinsic) {
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_input_vertex:
case nir_intrinsic_load_interpolated_input:
case nir_intrinsic_load_per_vertex_input:

View File

@@ -1344,6 +1344,7 @@ print_intrinsic_instr(nir_intrinsic_instr *instr, print_state *state)
nir_variable_mode mode = nir_var_mem_generic;
switch (instr->intrinsic) {
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_interpolated_input:
case nir_intrinsic_load_per_vertex_input:
case nir_intrinsic_load_input_vertex:
@@ -1370,9 +1371,6 @@ print_intrinsic_instr(nir_intrinsic_instr *instr, print_state *state)
fprintf(fp, "io location=%s slots=%u", loc, io.num_slots);
if (io.per_primitive)
fprintf(fp, " per_primitive");
if (io.interp_explicit_strict)
fprintf(fp, " explicit_strict");
@@ -1622,6 +1620,7 @@ print_intrinsic_instr(nir_intrinsic_instr *instr, print_state *state)
var_mode = nir_var_uniform;
break;
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_interpolated_input:
case nir_intrinsic_load_per_vertex_input:
var_mode = nir_var_shader_in;

View File

@@ -380,6 +380,7 @@ nir_schedule_intrinsic_deps(nir_deps_state *state,
break;
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_per_vertex_input:
add_read_dep(state, state->load_input, n);
break;

View File

@@ -613,6 +613,7 @@ validate_intrinsic_instr(nir_intrinsic_instr *instr, validate_state *state)
case nir_intrinsic_load_uniform:
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
case nir_intrinsic_load_per_vertex_input:
case nir_intrinsic_load_interpolated_input:
case nir_intrinsic_load_output:

View File

@@ -2123,6 +2123,7 @@ visit_intrinsic(struct lp_build_nir_context *bld_base,
visit_store_reg(bld_base, instr);
break;
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
visit_load_input(bld_base, instr, result);
break;
case nir_intrinsic_store_output:

View File

@@ -4209,7 +4209,8 @@ fs_nir_emit_fs_intrinsic(nir_to_brw_state &ntb,
break;
}
case nir_intrinsic_load_input: {
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input: {
/* In Fragment Shaders load_input is used either for flat inputs or
* per-primitive inputs.
*/

View File

@@ -275,6 +275,7 @@ static bool
is_input(nir_intrinsic_instr *intrin)
{
return intrin->intrinsic == nir_intrinsic_load_input ||
intrin->intrinsic == nir_intrinsic_load_per_primitive_input ||
intrin->intrinsic == nir_intrinsic_load_per_vertex_input ||
intrin->intrinsic == nir_intrinsic_load_interpolated_input;
}

View File

@@ -3834,7 +3834,8 @@ fs_nir_emit_fs_intrinsic(nir_to_elk_state &ntb,
break;
}
case nir_intrinsic_load_input: {
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input: {
/* In Fragment Shaders load_input is used either for flat inputs or
* per-primitive inputs.
*/

View File

@@ -168,6 +168,7 @@ static bool
is_input(nir_intrinsic_instr *intrin)
{
return intrin->intrinsic == nir_intrinsic_load_input ||
intrin->intrinsic == nir_intrinsic_load_per_primitive_input ||
intrin->intrinsic == nir_intrinsic_load_per_vertex_input ||
intrin->intrinsic == nir_intrinsic_load_interpolated_input;
}

View File

@@ -423,7 +423,8 @@ vec4_visitor::nir_emit_intrinsic(nir_intrinsic_instr *instr)
/* Nothing to do with these. */
break;
case nir_intrinsic_load_input: {
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input: {
assert(instr->def.bit_size == 32);
/* We set EmitNoIndirectInput for VS */
unsigned load_offset = nir_src_as_uint(instr->src[0]);