diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index c8aa26608e3..7d2f80ce491 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -5527,6 +5527,7 @@ typedef struct nir_lower_subgroups_options { bool lower_to_scalar : 1; bool lower_vote_trivial : 1; bool lower_vote_eq : 1; + bool lower_first_invocation_to_ballot : 1; bool lower_subgroup_masks : 1; bool lower_relative_shuffle : 1; bool lower_shuffle_to_32bit : 1; diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c index e679a7a7e9a..a7577bb5d84 100644 --- a/src/compiler/nir/nir_lower_subgroups.c +++ b/src/compiler/nir/nir_lower_subgroups.c @@ -549,6 +549,13 @@ lower_dynamic_quad_broadcast(nir_builder *b, nir_intrinsic_instr *intrin, return dst; } +static nir_def * +lower_first_invocation_to_ballot(nir_builder *b, nir_intrinsic_instr *intrin, + const nir_lower_subgroups_options *options) +{ + return nir_ballot_find_lsb(b, 32, nir_ballot(b, 4, 32, nir_imm_true(b))); +} + static nir_def * lower_read_invocation_to_cond(nir_builder *b, nir_intrinsic_instr *intrin) { @@ -588,6 +595,15 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options) return nir_imm_int(b, options->subgroup_size); break; + case nir_intrinsic_first_invocation: + if (options->subgroup_size == 1) + return nir_imm_int(b, 0); + + if (options->lower_first_invocation_to_ballot) + return lower_first_invocation_to_ballot(b, intrin, options); + + break; + case nir_intrinsic_read_invocation: if (options->lower_to_scalar && intrin->num_components > 1) return lower_subgroup_op_to_scalar(b, intrin);