diff --git a/src/asahi/compiler/agx_compile.c b/src/asahi/compiler/agx_compile.c index 3d1c995925c..a863bcda208 100644 --- a/src/asahi/compiler/agx_compile.c +++ b/src/asahi/compiler/agx_compile.c @@ -1717,7 +1717,9 @@ agx_compile_shader_nir(nir_shader *nir, agx_optimize_nir(nir); /* Implement conditional discard with real control flow like Metal */ - NIR_PASS_V(nir, nir_lower_discard_if); + NIR_PASS_V(nir, nir_lower_discard_if, (nir_lower_discard_if_to_cf | + nir_lower_demote_if_to_cf | + nir_lower_terminate_if_to_cf)); /* Must be last since NIR passes can remap driver_location freely */ if (ctx->stage == MESA_SHADER_VERTEX) diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index ed1486dec6f..fe0322faa79 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -5393,7 +5393,13 @@ typedef enum { bool nir_lower_interpolation(nir_shader *shader, nir_lower_interpolation_options options); -bool nir_lower_discard_if(nir_shader *shader); +typedef enum { + nir_lower_discard_if_to_cf = (1 << 0), + nir_lower_demote_if_to_cf = (1 << 1), + nir_lower_terminate_if_to_cf = (1 << 2), +} nir_lower_discard_if_options; + +bool nir_lower_discard_if(nir_shader *shader, nir_lower_discard_if_options options); bool nir_lower_discard_or_demote(nir_shader *shader, bool force_correct_quad_ops_after_discard); diff --git a/src/compiler/nir/nir_lower_discard_if.c b/src/compiler/nir/nir_lower_discard_if.c index ac34ad45ae2..b8de6426a04 100644 --- a/src/compiler/nir/nir_lower_discard_if.c +++ b/src/compiler/nir/nir_lower_discard_if.c @@ -25,26 +25,52 @@ #include "compiler/nir/nir_builder.h" static bool -lower_discard_if_instr(nir_builder *b, nir_instr *instr_, UNUSED void *cb_data) +lower_discard_if_instr(nir_builder *b, nir_instr *instr_, void *cb_data) { + nir_lower_discard_if_options options = *(nir_lower_discard_if_options *)cb_data; + if (instr_->type != nir_instr_type_intrinsic) return false; nir_intrinsic_instr *instr = nir_instr_as_intrinsic(instr_); - if (instr->intrinsic == nir_intrinsic_discard_if) { - b->cursor = nir_before_instr(&instr->instr); - - nir_if *if_stmt = nir_push_if(b, nir_ssa_for_src(b, instr->src[0], 1)); - nir_discard(b); - nir_pop_if(b, if_stmt); - nir_instr_remove(&instr->instr); - return true; - } else if (instr->intrinsic == nir_intrinsic_terminate_if || - instr->intrinsic == nir_intrinsic_demote_if) { - unreachable("todo: handle terminates and demotes for Vulkan"); + switch (instr->intrinsic) { + case nir_intrinsic_discard_if: + if (!(options & nir_lower_discard_if_to_cf)) + return false; + break; + case nir_intrinsic_demote_if: + if (!(options & nir_lower_demote_if_to_cf)) + return false; + break; + case nir_intrinsic_terminate_if: + if (!(options & nir_lower_terminate_if_to_cf)) + return false; + break; + default: + return false; } + b->cursor = nir_before_instr(&instr->instr); + + nir_if *if_stmt = nir_push_if(b, nir_ssa_for_src(b, instr->src[0], 1)); + switch (instr->intrinsic) { + case nir_intrinsic_discard_if: + nir_discard(b); + break; + case nir_intrinsic_demote_if: + nir_demote(b); + break; + case nir_intrinsic_terminate_if: + nir_terminate(b); + break; + default: + unreachable("bad intrinsic"); + } + nir_pop_if(b, if_stmt); + nir_instr_remove(&instr->instr); + return true; + /* a shader like this (shaders@glsl-fs-discard-04): uniform int j, k; @@ -94,10 +120,10 @@ lower_discard_if_instr(nir_builder *b, nir_instr *instr_, UNUSED void *cb_data) } bool -nir_lower_discard_if(nir_shader *shader) +nir_lower_discard_if(nir_shader *shader, nir_lower_discard_if_options options) { return nir_shader_instructions_pass(shader, lower_discard_if_instr, nir_metadata_none, - NULL); + &options); } diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c index cc73c714243..872c9927256 100644 --- a/src/gallium/drivers/zink/zink_compiler.c +++ b/src/gallium/drivers/zink/zink_compiler.c @@ -3041,7 +3041,9 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir, optimize_nir(nir, NULL); NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_function_temp, NULL); - NIR_PASS_V(nir, nir_lower_discard_if); + NIR_PASS_V(nir, nir_lower_discard_if, (nir_lower_discard_if_to_cf | + nir_lower_demote_if_to_cf | + nir_lower_terminate_if_to_cf)); NIR_PASS_V(nir, nir_lower_fragcolor, nir->info.fs.color_is_dual_source ? 1 : 8); NIR_PASS_V(nir, lower_64bit_vertex_attribs);