diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index cbf5c0746c3..1d1bb6c9584 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -4780,7 +4780,10 @@ bool nir_opt_trivial_continues(nir_shader *shader); bool nir_opt_undef(nir_shader *shader); -bool nir_opt_vectorize(nir_shader *shader); +typedef bool (*nir_opt_vectorize_cb)(const nir_instr *a, const nir_instr *b, + void *data); +bool nir_opt_vectorize(nir_shader *shader, nir_opt_vectorize_cb filter, + void *data); bool nir_opt_conditional_discard(nir_shader *shader); diff --git a/src/compiler/nir/nir_opt_vectorize.c b/src/compiler/nir/nir_opt_vectorize.c index 3bf65151e0d..1d372ead9fb 100644 --- a/src/compiler/nir/nir_opt_vectorize.c +++ b/src/compiler/nir/nir_opt_vectorize.c @@ -162,7 +162,8 @@ instr_can_rewrite(nir_instr *instr) */ static nir_instr * -instr_try_combine(struct nir_shader *nir, nir_instr *instr1, nir_instr *instr2) +instr_try_combine(struct nir_shader *nir, nir_instr *instr1, nir_instr *instr2, + nir_opt_vectorize_cb filter, void *data) { assert(instr1->type == nir_instr_type_alu); assert(instr2->type == nir_instr_type_alu); @@ -181,6 +182,9 @@ instr_try_combine(struct nir_shader *nir, nir_instr *instr1, nir_instr *instr2) (total_components > 2 || alu1->dest.dest.ssa.bit_size != 16)) return NULL; + if (filter && !filter(&alu1->instr, &alu2->instr, data)) + return NULL; + nir_builder b; nir_builder_init(&b, nir_cf_node_get_function(&instr1->block->cf_node)); b.cursor = nir_after_instr(instr1); @@ -320,13 +324,15 @@ vec_instr_stack_create(void *mem_ctx) static bool vec_instr_stack_push(struct nir_shader *nir, struct util_dynarray *stack, - nir_instr *instr) + nir_instr *instr, + nir_opt_vectorize_cb filter, void *data) { /* Walk the stack from child to parent to make live ranges shorter by * matching the closest thing we can */ util_dynarray_foreach_reverse(stack, nir_instr *, stack_instr) { - nir_instr *new_instr = instr_try_combine(nir, *stack_instr, instr); + nir_instr *new_instr = instr_try_combine(nir, *stack_instr, instr, + filter, data); if (new_instr) { *stack_instr = new_instr; return true; @@ -378,20 +384,21 @@ vec_instr_set_destroy(struct set *instr_set) static bool vec_instr_set_add_or_rewrite(struct nir_shader *nir, struct set *instr_set, - nir_instr *instr) + nir_instr *instr, + nir_opt_vectorize_cb filter, void *data) { if (!instr_can_rewrite(instr)) return false; struct util_dynarray *new_stack = vec_instr_stack_create(instr_set); - vec_instr_stack_push(nir, new_stack, instr); + vec_instr_stack_push(nir, new_stack, instr, filter, data); struct set_entry *entry = _mesa_set_search(instr_set, new_stack); if (entry) { ralloc_free(new_stack); struct util_dynarray *stack = (struct util_dynarray *) entry->key; - return vec_instr_stack_push(nir, stack, instr); + return vec_instr_stack_push(nir, stack, instr, filter, data); } _mesa_set_add(instr_set, new_stack); @@ -400,7 +407,7 @@ vec_instr_set_add_or_rewrite(struct nir_shader *nir, struct set *instr_set, static void vec_instr_set_remove(struct nir_shader *nir, struct set *instr_set, - nir_instr *instr) + nir_instr *instr, nir_opt_vectorize_cb filter, void *data) { if (!instr_can_rewrite(instr)) return; @@ -417,7 +424,7 @@ vec_instr_set_remove(struct nir_shader *nir, struct set *instr_set, * comparison function as well. */ struct util_dynarray *temp = vec_instr_stack_create(instr_set); - vec_instr_stack_push(nir, temp, instr); + vec_instr_stack_push(nir, temp, instr, filter, data); struct set_entry *entry = _mesa_set_search(instr_set, temp); ralloc_free(temp); @@ -433,34 +440,37 @@ vec_instr_set_remove(struct nir_shader *nir, struct set *instr_set, static bool vectorize_block(struct nir_shader *nir, nir_block *block, - struct set *instr_set) + struct set *instr_set, + nir_opt_vectorize_cb filter, void *data) { bool progress = false; nir_foreach_instr_safe(instr, block) { - if (vec_instr_set_add_or_rewrite(nir, instr_set, instr)) + if (vec_instr_set_add_or_rewrite(nir, instr_set, instr, filter, data)) progress = true; } for (unsigned i = 0; i < block->num_dom_children; i++) { nir_block *child = block->dom_children[i]; - progress |= vectorize_block(nir, child, instr_set); + progress |= vectorize_block(nir, child, instr_set, filter, data); } nir_foreach_instr_reverse(instr, block) - vec_instr_set_remove(nir, instr_set, instr); + vec_instr_set_remove(nir, instr_set, instr, filter, data); return progress; } static bool -nir_opt_vectorize_impl(struct nir_shader *nir, nir_function_impl *impl) +nir_opt_vectorize_impl(struct nir_shader *nir, nir_function_impl *impl, + nir_opt_vectorize_cb filter, void *data) { struct set *instr_set = vec_instr_set_create(); nir_metadata_require(impl, nir_metadata_dominance); - bool progress = vectorize_block(nir, nir_start_block(impl), instr_set); + bool progress = vectorize_block(nir, nir_start_block(impl), instr_set, + filter, data); if (progress) nir_metadata_preserve(impl, nir_metadata_block_index | @@ -471,13 +481,14 @@ nir_opt_vectorize_impl(struct nir_shader *nir, nir_function_impl *impl) } bool -nir_opt_vectorize(nir_shader *shader) +nir_opt_vectorize(nir_shader *shader, nir_opt_vectorize_cb filter, + void *data) { bool progress = false; nir_foreach_function(function, shader) { if (function->impl) - progress |= nir_opt_vectorize_impl(shader, function->impl); + progress |= nir_opt_vectorize_impl(shader, function->impl, filter, data); } return progress; diff --git a/src/gallium/drivers/etnaviv/etnaviv_compiler_nir.c b/src/gallium/drivers/etnaviv/etnaviv_compiler_nir.c index 79ef97cdbde..b09a720e43c 100644 --- a/src/gallium/drivers/etnaviv/etnaviv_compiler_nir.c +++ b/src/gallium/drivers/etnaviv/etnaviv_compiler_nir.c @@ -1122,7 +1122,7 @@ etna_compile_shader_nir(struct etna_shader_variant *v) if (DBG_ENABLED(ETNA_DBG_DUMP_SHADERS)) nir_print_shader(s, stdout); - while( OPT(s, nir_opt_vectorize) ); + while( OPT(s, nir_opt_vectorize, NULL, NULL) ); NIR_PASS_V(s, nir_lower_alu_to_scalar, etna_alu_to_scalar_filter_cb, specs); NIR_PASS_V(s, nir_remove_dead_variables, nir_var_function_temp, NULL); diff --git a/src/gallium/drivers/lima/lima_program.c b/src/gallium/drivers/lima/lima_program.c index 1bb1bd88622..87029d3140a 100644 --- a/src/gallium/drivers/lima/lima_program.c +++ b/src/gallium/drivers/lima/lima_program.c @@ -201,7 +201,7 @@ lima_program_optimize_fs_nir(struct nir_shader *s, do { progress = false; - NIR_PASS(progress, s, nir_opt_vectorize); + NIR_PASS(progress, s, nir_opt_vectorize, NULL, NULL); } while (progress); do { diff --git a/src/gallium/drivers/r600/r600_shader.c b/src/gallium/drivers/r600/r600_shader.c index b7dab54daab..9ac0efad7e4 100644 --- a/src/gallium/drivers/r600/r600_shader.c +++ b/src/gallium/drivers/r600/r600_shader.c @@ -201,7 +201,7 @@ int r600_pipe_shader_create(struct pipe_context *ctx, NIR_PASS_V(sel->nir, nir_lower_regs_to_ssa); NIR_PASS_V(sel->nir, nir_lower_alu_to_scalar, NULL, NULL); NIR_PASS_V(sel->nir, nir_lower_int64); - NIR_PASS_V(sel->nir, nir_opt_vectorize); + NIR_PASS_V(sel->nir, nir_opt_vectorize, NULL, NULL); } NIR_PASS_V(sel->nir, nir_lower_flrp, ~0, false, false); } diff --git a/src/gallium/drivers/r600/sfn/sfn_nir.cpp b/src/gallium/drivers/r600/sfn/sfn_nir.cpp index cbaaa7eb147..9f75726f2ac 100644 --- a/src/gallium/drivers/r600/sfn/sfn_nir.cpp +++ b/src/gallium/drivers/r600/sfn/sfn_nir.cpp @@ -768,7 +768,7 @@ optimize_once(nir_shader *shader) NIR_PASS(progress, shader, nir_opt_algebraic); NIR_PASS(progress, shader, nir_opt_constant_folding); NIR_PASS(progress, shader, nir_opt_copy_prop_vars); - NIR_PASS(progress, shader, nir_opt_vectorize); + NIR_PASS(progress, shader, nir_opt_vectorize, NULL, NULL); NIR_PASS(progress, shader, nir_opt_remove_phis); diff --git a/src/panfrost/midgard/midgard_compile.c b/src/panfrost/midgard/midgard_compile.c index a58701de75f..bd13efd0eca 100644 --- a/src/panfrost/midgard/midgard_compile.c +++ b/src/panfrost/midgard/midgard_compile.c @@ -558,7 +558,7 @@ optimise_nir(nir_shader *nir, unsigned quirks, bool is_blend) nir_var_shader_out | nir_var_function_temp); - NIR_PASS(progress, nir, nir_opt_vectorize); + NIR_PASS(progress, nir, nir_opt_vectorize, NULL, NULL); } while (progress); NIR_PASS_V(nir, nir_lower_alu_to_scalar, mdg_is_64, NULL);