diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c index 405298a00d5..34ff7e16bcc 100644 --- a/src/amd/common/ac_nir_lower_ngg.c +++ b/src/amd/common/ac_nir_lower_ngg.c @@ -25,6 +25,7 @@ #include "ac_nir.h" #include "nir_builder.h" #include "u_math.h" +#include "u_vector.h" enum { nggc_passflag_used_by_pos = 1, @@ -32,6 +33,12 @@ enum { nggc_passflag_used_by_both = nggc_passflag_used_by_pos | nggc_passflag_used_by_other, }; +typedef struct +{ + nir_ssa_def *ssa; + nir_variable *var; +} saved_uniform; + typedef struct { nir_variable *position_value_var; @@ -39,6 +46,8 @@ typedef struct nir_variable *es_accepted_var; nir_variable *gs_accepted_var; + struct u_vector saved_uniforms; + bool passthrough; bool export_prim_id; bool early_prim_export; @@ -717,6 +726,156 @@ analyze_shader_before_culling(nir_shader *shader, lower_ngg_nogs_state *nogs_sta } } +/** + * Save the reusable SSA definitions to variables so that the + * bottom shader part can reuse them from the top part. + * + * 1. We create a new function temporary variable for reusables, + * and insert a store+load. + * 2. The shader is cloned (the top part is created), then the + * control flow is reinserted (for the bottom part.) + * 3. For reusables, we delete the variable stores from the + * bottom part. This will make them use the variables from + * the top part and DCE the redundant instructions. + */ +static void +save_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state) +{ + ASSERTED int vec_ok = u_vector_init(&nogs_state->saved_uniforms, sizeof(saved_uniform), 4 * sizeof(saved_uniform)); + assert(vec_ok); + + unsigned loop_depth = 0; + + nir_foreach_block_safe(block, b->impl) { + /* Check whether we're in a loop. */ + nir_cf_node *next_cf_node = nir_cf_node_next(&block->cf_node); + nir_cf_node *prev_cf_node = nir_cf_node_prev(&block->cf_node); + if (next_cf_node && next_cf_node->type == nir_cf_node_loop) + loop_depth++; + if (prev_cf_node && prev_cf_node->type == nir_cf_node_loop) + loop_depth--; + + /* The following code doesn't make sense in loops, so just skip it then. */ + if (loop_depth) + continue; + + nir_foreach_instr_safe(instr, block) { + /* Find instructions whose SSA definitions are used by both + * the top and bottom parts of the shader. In this case, it + * makes sense to try to reuse these from the top part. + */ + if ((instr->pass_flags & nggc_passflag_used_by_both) != nggc_passflag_used_by_both) + continue; + + nir_ssa_def *ssa = NULL; + + switch (instr->type) { + case nir_instr_type_alu: { + nir_alu_instr *alu = nir_instr_as_alu(instr); + if (alu->dest.dest.ssa.divergent) + continue; + /* Ignore uniform floats because they regress VGPR usage too much */ + if (nir_op_infos[alu->op].output_type & nir_type_float) + continue; + ssa = &alu->dest.dest.ssa; + break; + } + case nir_instr_type_intrinsic: { + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); + if (!nir_intrinsic_can_reorder(intrin) || + !nir_intrinsic_infos[intrin->intrinsic].has_dest || + intrin->dest.ssa.divergent) + continue; + ssa = &intrin->dest.ssa; + break; + } + case nir_instr_type_phi: { + nir_phi_instr *phi = nir_instr_as_phi(instr); + if (phi->dest.ssa.divergent) + continue; + ssa = &phi->dest.ssa; + break; + } + default: + continue; + } + + assert(ssa); + + enum glsl_base_type base_type = GLSL_TYPE_UINT; + switch (ssa->bit_size) { + case 8: base_type = GLSL_TYPE_UINT8; break; + case 16: base_type = GLSL_TYPE_UINT16; break; + case 32: base_type = GLSL_TYPE_UINT; break; + case 64: base_type = GLSL_TYPE_UINT64; break; + default: continue; + } + + const struct glsl_type *t = ssa->num_components == 1 + ? glsl_scalar_type(base_type) + : glsl_vector_type(base_type, ssa->num_components); + + saved_uniform *saved = (saved_uniform *) u_vector_add(&nogs_state->saved_uniforms); + assert(saved); + + saved->var = nir_local_variable_create(b->impl, t, NULL); + saved->ssa = ssa; + + b->cursor = instr->type == nir_instr_type_phi + ? nir_after_instr_and_phis(instr) + : nir_after_instr(instr); + nir_store_var(b, saved->var, saved->ssa, BITFIELD_MASK(ssa->num_components)); + nir_ssa_def *reloaded = nir_load_var(b, saved->var); + nir_ssa_def_rewrite_uses_after(ssa, reloaded, reloaded->parent_instr); + } + } +} + +/** + * Reuses suitable variables from the top part of the shader, + * by deleting their stores from the bottom part. + */ +static void +apply_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state) +{ + if (!u_vector_length(&nogs_state->saved_uniforms)) { + u_vector_finish(&nogs_state->saved_uniforms); + return; + } + + nir_foreach_block_reverse_safe(block, b->impl) { + nir_foreach_instr_reverse_safe(instr, block) { + if (instr->type != nir_instr_type_intrinsic) + continue; + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); + + /* When we found any of these intrinsics, it means + * we reached the top part and we must stop. + */ + if (intrin->intrinsic == nir_intrinsic_overwrite_subgroup_num_vertices_and_primitives_amd || + intrin->intrinsic == nir_intrinsic_alloc_vertices_and_primitives_amd || + intrin->intrinsic == nir_intrinsic_export_primitive_amd) + goto done; + + if (intrin->intrinsic != nir_intrinsic_store_deref) + continue; + nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]); + if (deref->deref_type != nir_deref_type_var) + continue; + + saved_uniform *saved; + u_vector_foreach(saved, &nogs_state->saved_uniforms) { + if (saved->var == deref->var) { + nir_instr_remove(instr); + } + } + } + } + + done: + u_vector_finish(&nogs_state->saved_uniforms); +} + static void add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_cf, lower_ngg_nogs_state *nogs_state) { @@ -1025,6 +1184,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader, /* We need divergence info for culling shaders. */ nir_divergence_analysis(shader); analyze_shader_before_culling(shader, &state); + save_reusable_variables(b, &state); } nir_cf_list extracted; @@ -1082,6 +1242,9 @@ ac_nir_lower_ngg_nogs(nir_shader *shader, } if (can_cull) { + /* Replace uniforms. */ + apply_reusable_variables(b, &state); + /* Remove the redundant position output. */ remove_extra_pos_outputs(shader, &state);