diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 20b31e915d3..0bf59e44d36 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -4801,6 +4801,7 @@ bool nir_repair_ssa(nir_shader *shader);
 void nir_convert_loop_to_lcssa(nir_loop *loop);
 bool nir_convert_to_lcssa(nir_shader *shader, bool skip_invariants, bool skip_bool_invariants);
 void nir_divergence_analysis(nir_shader *shader);
+bool nir_update_instr_divergence(nir_shader *shader, nir_instr *instr);
 
 /* If phi_webs_only is true, only convert SSA values involved in phi nodes to
  * registers. If false, convert all values (even those not involved in a phi
diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h
index 2b340bac700..1b43892736a 100644
--- a/src/compiler/nir/nir_builder.h
+++ b/src/compiler/nir/nir_builder.h
@@ -36,6 +36,10 @@ typedef struct nir_builder {
    /* Whether new ALU instructions will be marked "exact" */
    bool exact;
 
+   /* Whether to run divergence analysis on inserted instructions (loop merge
+    * and header phis are not updated). */
+   bool update_divergence;
+
    nir_shader *shader;
    nir_function_impl *impl;
 } nir_builder;
@@ -54,6 +58,7 @@ nir_builder_init_simple_shader(nir_builder *build, void *mem_ctx,
                                gl_shader_stage stage,
                                const nir_shader_compiler_options *options)
 {
+   memset(build, 0, sizeof(*build));
    build->shader = nir_shader_create(mem_ctx, stage, options, NULL);
    nir_function *func = nir_function_create(build->shader, "main");
    func->is_entrypoint = true;
@@ -110,6 +115,9 @@ nir_builder_instr_insert(nir_builder *build, nir_instr *instr)
 {
    nir_instr_insert(build->cursor, instr);
 
+   if (build->update_divergence)
+      nir_update_instr_divergence(build->shader, instr);
+
    /* Move the cursor forward. */
    build->cursor = nir_after_instr(instr);
 }
@@ -237,6 +245,8 @@ nir_ssa_undef(nir_builder *build, unsigned num_components, unsigned bit_size)
       return NULL;
 
    nir_instr_insert(nir_before_cf_list(&build->impl->body), &undef->instr);
+   if (build->update_divergence)
+      nir_update_instr_divergence(build->shader, &undef->instr);
 
    return &undef->def;
 }
diff --git a/src/compiler/nir/nir_divergence_analysis.c b/src/compiler/nir/nir_divergence_analysis.c
index ccfc7ded08f..cd33996f5c9 100644
--- a/src/compiler/nir/nir_divergence_analysis.c
+++ b/src/compiler/nir/nir_divergence_analysis.c
@@ -614,6 +614,35 @@ set_ssa_def_not_divergent(nir_ssa_def *def, UNUSED void *_state)
    return true;
 }
 
+/* Recompute the divergence of one instruction by dispatching on its type to
+ * the matching visit_*() helper and returning that helper's result. Jump,
+ * phi, call and parallel-copy instructions are not supported here and hit
+ * unreachable(); callers must filter them out first. */
+static bool
+update_instr_divergence(nir_shader *shader, nir_instr *instr)
+{
+   switch (instr->type) {
+   case nir_instr_type_alu:
+      return visit_alu(nir_instr_as_alu(instr));
+   case nir_instr_type_intrinsic:
+      return visit_intrinsic(shader, nir_instr_as_intrinsic(instr));
+   case nir_instr_type_tex:
+      return visit_tex(nir_instr_as_tex(instr));
+   case nir_instr_type_load_const:
+      return visit_load_const(nir_instr_as_load_const(instr));
+   case nir_instr_type_ssa_undef:
+      return visit_ssa_undef(nir_instr_as_ssa_undef(instr));
+   case nir_instr_type_deref:
+      return visit_deref(shader, nir_instr_as_deref(instr));
+   case nir_instr_type_jump:
+   case nir_instr_type_phi:
+   case nir_instr_type_call:
+   case nir_instr_type_parallel_copy:
+   default:
+      unreachable("NIR divergence analysis: Unsupported instruction type.");
+   }
+}
+
 static bool
 visit_block(nir_block *block, struct divergence_state *state)
 {
@@ -627,33 +656,10 @@ visit_block(nir_block *block, struct divergence_state *state)
       if (state->first_visit)
          nir_foreach_ssa_def(instr, set_ssa_def_not_divergent, NULL);
 
-      switch (instr->type) {
-      case nir_instr_type_alu:
-         has_changed |= visit_alu(nir_instr_as_alu(instr));
-         break;
-      case nir_instr_type_intrinsic:
-         has_changed |= visit_intrinsic(state->shader, nir_instr_as_intrinsic(instr));
-         break;
-      case nir_instr_type_tex:
-         has_changed |= visit_tex(nir_instr_as_tex(instr));
-         break;
-      case nir_instr_type_load_const:
-         has_changed |= visit_load_const(nir_instr_as_load_const(instr));
-         break;
-      case nir_instr_type_ssa_undef:
-         has_changed |= visit_ssa_undef(nir_instr_as_ssa_undef(instr));
-         break;
-      case nir_instr_type_deref:
-         has_changed |= visit_deref(state->shader, nir_instr_as_deref(instr));
-         break;
-      case nir_instr_type_jump:
+      if (instr->type == nir_instr_type_jump)
          has_changed |= visit_jump(nir_instr_as_jump(instr), state);
-         break;
-      case nir_instr_type_phi:
-      case nir_instr_type_call:
-      case nir_instr_type_parallel_copy:
-         unreachable("NIR divergence analysis: Unsupported instruction type.");
-      }
+      else
+         has_changed |= update_instr_divergence(state->shader, instr);
    }
 
    return has_changed;
@@ -903,3 +909,40 @@ nir_divergence_analysis(nir_shader *shader)
    visit_cf_list(&nir_shader_get_entrypoint(shader)->body, &state);
 }
 
+/* Incrementally recompute the divergence of the SSA defs of one instruction.
+ *
+ * Returns true if the defs were updated. Returns false, leaving the defs
+ * untouched, for instructions that cannot be handled in isolation: jumps,
+ * calls, parallel copies, and phis whose block does not directly follow an if.
+ */
+bool
+nir_update_instr_divergence(nir_shader *shader, nir_instr *instr)
+{
+   switch (instr->type) {
+   case nir_instr_type_jump:
+   case nir_instr_type_call:
+   case nir_instr_type_parallel_copy:
+      /* update_instr_divergence() treats these as unreachable, but
+       * nir_builder_instr_insert() passes every inserted instruction here. */
+      return false;
+   case nir_instr_type_phi: {
+      nir_cf_node *prev = nir_cf_node_prev(&instr->block->cf_node);
+      /* can only update gamma/if phis */
+      if (!prev || prev->type != nir_cf_node_if)
+         return false;
+
+      nir_if *nif = nir_cf_node_as_if(prev);
+
+      nir_foreach_ssa_def(instr, set_ssa_def_not_divergent, NULL);
+      visit_if_merge_phi(nir_instr_as_phi(instr), nir_src_is_divergent(nif->condition));
+      return true;
+   }
+   default:
+      break;
+   }
+
+   nir_foreach_ssa_def(instr, set_ssa_def_not_divergent, NULL);
+   update_instr_divergence(shader, instr);
+   return true;
+}
+