diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c
index d02e2e36c14..5cf91b0a6e7 100644
--- a/src/amd/vulkan/radv_pipeline.c
+++ b/src/amd/vulkan/radv_pipeline.c
@@ -4042,14 +4042,16 @@ lower_bit_size_callback(const nir_instr *instr, void *_)
    return 0;
 }
 
-static bool
-opt_vectorize_callback(const nir_instr *instr, void *_)
+static uint8_t
+opt_vectorize_callback(const nir_instr *instr, const void *_)
 {
-   assert(instr->type == nir_instr_type_alu);
-   nir_alu_instr *alu = nir_instr_as_alu(instr);
-   unsigned bit_size = alu->dest.dest.ssa.bit_size;
+   if (instr->type != nir_instr_type_alu)
+      return 0;
+
+   const nir_alu_instr *alu = nir_instr_as_alu(instr);
+   const unsigned bit_size = alu->dest.dest.ssa.bit_size;
    if (bit_size != 16)
-      return false;
+      return 1;
 
    switch (alu->op) {
    case nir_op_fadd:
@@ -4069,12 +4071,12 @@ opt_vectorize_callback(const nir_instr *instr, void *_)
    case nir_op_imax:
    case nir_op_umin:
    case nir_op_umax:
-      return true;
+      return 2;
    case nir_op_ishl: /* TODO: in NIR, these have 32bit shift operands */
    case nir_op_ishr: /* while Radeon needs 16bit operands when vectorized */
    case nir_op_ushr:
    default:
-      return false;
+      return 1;
    }
 }
 
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index a3e7485ff30..3ab0696a5a5 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -3228,6 +3228,15 @@ typedef enum {
  */
 typedef bool (*nir_instr_filter_cb)(const nir_instr *, const void *);
 
+/** A vectorization width callback
+ *
+ * Returns the maximum vectorization width per instruction,
+ * or 0 if the instruction must not be modified.
+ *
+ * The vectorization width must be a power of 2.
+ */
+typedef uint8_t (*nir_vectorize_cb)(const nir_instr *, const void *);
+
 typedef struct nir_shader_compiler_options {
    bool lower_fdiv;
    bool lower_ffma16;
@@ -3455,7 +3464,11 @@ typedef struct nir_shader_compiler_options {
    nir_instr_filter_cb lower_to_scalar_filter;
 
    /**
-    * Whether nir_opt_vectorize should only create 16-bit 2D vectors.
+    * Disables potentially harmful algebraic transformations for architectures
+    * with SIMD-within-a-register semantics.
+    *
+    * Note that to actually vectorize 16-bit instructions, use
+    * nir_opt_vectorize() with a suitable callback function.
     */
    bool vectorize_vec2_16bit;
 
@@ -5485,9 +5498,7 @@ bool nir_lower_undef_to_zero(nir_shader *shader);
 
 bool nir_opt_uniform_atomics(nir_shader *shader);
 
-typedef bool (*nir_opt_vectorize_cb)(const nir_instr *instr, void *data);
-
-bool nir_opt_vectorize(nir_shader *shader, nir_opt_vectorize_cb filter,
+bool nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
                        void *data);
 
 bool nir_opt_conditional_discard(nir_shader *shader);
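The new typedef replaces the old boolean filter with a width query. A minimal sketch of the contract (the my_backend_vectorize_cb name and its vec2-for-16-bit policy are illustrative, not part of this patch): return 0 to leave an instruction untouched, 1 to keep it scalar, or a larger power of two to cap the vectorization width.

static uint8_t
my_backend_vectorize_cb(const nir_instr *instr, const void *data)
{
   /* Only ALU instructions are candidates in nir_opt_vectorize(). */
   if (instr->type != nir_instr_type_alu)
      return 0;

   const nir_alu_instr *alu = nir_instr_as_alu(instr);

   /* Illustrative policy: pack 16-bit ALU into vec2 (e.g. for packed-fp16
    * hardware), allow the default vec4 width everywhere else. */
   return alu->dest.dest.ssa.bit_size == 16 ? 2 : 4;
}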
diff --git a/src/compiler/nir/nir_opt_vectorize.c b/src/compiler/nir/nir_opt_vectorize.c
index 83c841ee63e..dc6e1d84b52 100644
--- a/src/compiler/nir/nir_opt_vectorize.c
+++ b/src/compiler/nir/nir_opt_vectorize.c
@@ -22,6 +22,16 @@
  *
  */
 
+/**
+ * nir_opt_vectorize() aims to vectorize ALU instructions.
+ *
+ * The default vectorization width is 4.
+ * If desired, a callback function which returns the max vectorization width
+ * per instruction can be provided.
+ *
+ * The max vectorization width must be a power of 2.
+ */
+
 #include "nir.h"
 #include "nir_vla.h"
 #include "nir_builder.h"
@@ -125,7 +135,7 @@ instrs_equal(const void *data1, const void *data2)
 }
 
 static bool
-instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
+instr_can_rewrite(nir_instr *instr)
 {
    switch (instr->type) {
    case nir_instr_type_alu: {
@@ -139,12 +149,7 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
          return false;
 
       /* no need to hash instructions which are already vectorized */
-      if (alu->dest.dest.ssa.num_components >= 4)
-         return false;
-
-      if (vectorize_16bit &&
-          (alu->dest.dest.ssa.num_components >= 2 ||
-           alu->dest.dest.ssa.bit_size != 16))
+      if (alu->dest.dest.ssa.num_components >= instr->pass_flags)
          return false;
 
       if (nir_op_infos[alu->op].output_size != 0)
@@ -156,8 +161,8 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
          /* don't hash instructions which are already swizzled
           * outside of max_components: these should better be scalarized */
-         uint32_t mask = vectorize_16bit ? ~1 : ~3;
-         for (unsigned j = 0; j < alu->dest.dest.ssa.num_components; j++) {
+         uint32_t mask = ~(instr->pass_flags - 1);
+         for (unsigned j = 1; j < alu->dest.dest.ssa.num_components; j++) {
             if ((alu->src[i].swizzle[0] & mask) !=
                 (alu->src[i].swizzle[j] & mask))
                return false;
          }
@@ -179,10 +184,8 @@
  * the same instructions into one vectorized instruction. Note that instr1
  * should dominate instr2.
  */
-
 static nir_instr *
-instr_try_combine(struct nir_shader *nir, struct set *instr_set,
-                  nir_instr *instr1, nir_instr *instr2)
+instr_try_combine(struct set *instr_set, nir_instr *instr1, nir_instr *instr2)
 {
    assert(instr1->type == nir_instr_type_alu);
    assert(instr2->type == nir_instr_type_alu);
@@ -194,14 +197,10 @@
    unsigned alu2_components = alu2->dest.dest.ssa.num_components;
    unsigned total_components = alu1_components + alu2_components;
 
-   if (total_components > 4)
+   assert(instr1->pass_flags == instr2->pass_flags);
+   if (total_components > instr1->pass_flags)
       return NULL;
 
-   if (nir->options->vectorize_vec2_16bit) {
-      assert(total_components == 2);
-      assert(alu1->dest.dest.ssa.bit_size == 16);
-   }
-
    nir_builder b;
    nir_builder_init(&b, nir_cf_node_get_function(&instr1->block->cf_node));
    b.cursor = nir_after_instr(instr1);
@@ -352,28 +351,23 @@ vec_instr_set_destroy(struct set *instr_set)
 }
 
 static bool
-vec_instr_set_add_or_rewrite(struct nir_shader *nir, struct set *instr_set,
-                             nir_instr *instr,
-                             nir_opt_vectorize_cb filter, void *data)
+vec_instr_set_add_or_rewrite(struct set *instr_set, nir_instr *instr,
+                             nir_vectorize_cb filter, void *data)
 {
-   if (!instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit))
-      return false;
-
-   if (filter && !filter(instr, data))
-      return false;
-
    /* set max vector to instr pass flags: this is used to hash swizzles */
-   instr->pass_flags = nir->options->vectorize_vec2_16bit ? 2 : 4;
+   instr->pass_flags = filter ? filter(instr, data) : 4;
+   assert(util_is_power_of_two_or_zero(instr->pass_flags));
+
+   if (!instr_can_rewrite(instr))
+      return false;
 
    struct set_entry *entry = _mesa_set_search(instr_set, instr);
    if (entry) {
       nir_instr *old_instr = (nir_instr *) entry->key;
       _mesa_set_remove(instr_set, entry);
-      nir_instr *new_instr = instr_try_combine(nir, instr_set,
-                                               old_instr, instr);
+      nir_instr *new_instr = instr_try_combine(instr_set, old_instr, instr);
       if (new_instr) {
-         if (instr_can_rewrite(new_instr, nir->options->vectorize_vec2_16bit) &&
-             (!filter || filter(new_instr, data)))
+         if (instr_can_rewrite(new_instr))
             _mesa_set_add(instr_set, new_instr);
          return true;
       }
@@ -384,25 +378,23 @@
 }
 
 static bool
-vectorize_block(struct nir_shader *nir, nir_block *block,
-                struct set *instr_set,
-                nir_opt_vectorize_cb filter, void *data)
+vectorize_block(nir_block *block, struct set *instr_set,
+                nir_vectorize_cb filter, void *data)
 {
    bool progress = false;
 
    nir_foreach_instr_safe(instr, block) {
-      if (vec_instr_set_add_or_rewrite(nir, instr_set, instr, filter, data))
+      if (vec_instr_set_add_or_rewrite(instr_set, instr, filter, data))
         progress = true;
    }
 
    for (unsigned i = 0; i < block->num_dom_children; i++) {
       nir_block *child = block->dom_children[i];
-      progress |= vectorize_block(nir, child, instr_set, filter, data);
+      progress |= vectorize_block(child, instr_set, filter, data);
    }
 
    nir_foreach_instr_reverse(instr, block) {
-      if (instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit) &&
-          (!filter || filter(instr, data)))
+      if (instr_can_rewrite(instr))
         _mesa_set_remove_key(instr_set, instr);
    }
 
@@ -410,14 +402,14 @@
 }
 
 static bool
-nir_opt_vectorize_impl(struct nir_shader *nir, nir_function_impl *impl,
-                       nir_opt_vectorize_cb filter, void *data)
+nir_opt_vectorize_impl(nir_function_impl *impl,
+                       nir_vectorize_cb filter, void *data)
 {
    struct set *instr_set = vec_instr_set_create();
 
    nir_metadata_require(impl, nir_metadata_dominance);
 
-   bool progress = vectorize_block(nir, nir_start_block(impl), instr_set,
+   bool progress = vectorize_block(nir_start_block(impl), instr_set,
                                    filter, data);
 
    if (progress) {
@@ -432,14 +424,14 @@
 }
 
 bool
-nir_opt_vectorize(nir_shader *shader, nir_opt_vectorize_cb filter,
+nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
                   void *data)
 {
    bool progress = false;
 
    nir_foreach_function(function, shader) {
       if (function->impl)
-         progress |= nir_opt_vectorize_impl(shader, function->impl, filter, data);
+         progress |= nir_opt_vectorize_impl(function->impl, filter, data);
    }
 
    return progress;
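The generalized swizzle mask above is easiest to see with concrete values (commentary; the numbers follow directly from the arithmetic in instr_can_rewrite()):

/* With pass_flags == 2 (the old vectorize_vec2_16bit mode), mask == ~1:
 *   swizzle {2, 3}: (2 & ~1) == (3 & ~1) == 2    -> same aligned pair, hashable
 *   swizzle {1, 2}: (1 & ~1) == 0, (2 & ~1) == 2 -> crosses pairs, rejected
 * With the default pass_flags == 4, mask == ~3 groups components into
 * aligned vec4 slots, reproducing the old hard-coded ~3.
 */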
diff --git a/src/gallium/auxiliary/nir/nir_to_tgsi.c b/src/gallium/auxiliary/nir/nir_to_tgsi.c
index 2993fc27e22..371b43c8a8e 100644
--- a/src/gallium/auxiliary/nir/nir_to_tgsi.c
+++ b/src/gallium/auxiliary/nir/nir_to_tgsi.c
@@ -3067,11 +3067,11 @@ type_size(const struct glsl_type *type, bool bindless)
 /* Allow vectorizing of ALU instructions, but avoid vectorizing past what we
  * can handle for 64-bit values in TGSI.
  */
-static bool
-ntt_should_vectorize_instr(const nir_instr *instr, void *data)
+static uint8_t
+ntt_should_vectorize_instr(const nir_instr *instr, const void *data)
 {
    if (instr->type != nir_instr_type_alu)
-      return false;
+      return 0;
 
    nir_alu_instr *alu = nir_instr_as_alu(instr);
 
@@ -3085,7 +3085,7 @@ ntt_should_vectorize_instr(const nir_instr *instr, void *data)
        *
        * https://gitlab.freedesktop.org/virgl/virglrenderer/-/issues/195
        */
-      return false;
+      return 1;
 
    default:
       break;
@@ -3102,10 +3102,10 @@ ntt_should_vectorize_instr(const nir_instr *instr, void *data)
        * 64-bit instrs in the first place, I don't see much reason to care about
        * this.
        */
-      return false;
+      return 1;
    }
 
-   return true;
+   return 4;
 }
 
 static bool
diff --git a/src/gallium/drivers/radeonsi/si_shader_nir.c b/src/gallium/drivers/radeonsi/si_shader_nir.c
index 8b267023504..ec1013f8d74 100644
--- a/src/gallium/drivers/radeonsi/si_shader_nir.c
+++ b/src/gallium/drivers/radeonsi/si_shader_nir.c
@@ -43,6 +43,18 @@ static bool si_alu_to_scalar_filter(const nir_instr *instr, const void *data)
    return true;
 }
 
+static uint8_t si_vectorize_callback(const nir_instr *instr, const void *data)
+{
+   if (instr->type != nir_instr_type_alu)
+      return 0;
+
+   nir_alu_instr *alu = nir_instr_as_alu(instr);
+   if (nir_dest_bit_size(alu->dest.dest) == 16)
+      return 2;
+
+   return 1;
+}
+
 void si_nir_opts(struct si_screen *sscreen, struct nir_shader *nir, bool first)
 {
    bool progress;
@@ -114,7 +126,7 @@ void si_nir_opts(struct si_screen *sscreen, struct nir_shader *nir, bool first)
          NIR_PASS_V(nir, nir_opt_move_discards_to_top);
 
       if (sscreen->options.fp16)
-         NIR_PASS(progress, nir, nir_opt_vectorize, NULL, NULL);
+         NIR_PASS(progress, nir, nir_opt_vectorize, si_vectorize_callback, NULL);
    } while (progress);
 
    NIR_PASS_V(nir, nir_lower_var_copies);
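A note on the two call-site styles (commentary; both NIR_PASS lines appear in the hunks above and below): passing NULL keeps the historical default width of 4, so only drivers that want a different policy need to supply a callback.

/* radeonsi (above): cap 16-bit ALU at vec2, keep other ALU scalar. */
NIR_PASS(progress, nir, nir_opt_vectorize, si_vectorize_callback, NULL);

/* st/mesa (below): no callback, i.e. the default width of 4. */
NIR_PASS_V(nir, nir_opt_vectorize, nullptr, nullptr);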
 */
diff --git a/src/mesa/state_tracker/st_glsl_to_nir.cpp b/src/mesa/state_tracker/st_glsl_to_nir.cpp
index aca156de2f0..e6bf9b8fd91 100644
--- a/src/mesa/state_tracker/st_glsl_to_nir.cpp
+++ b/src/mesa/state_tracker/st_glsl_to_nir.cpp
@@ -517,7 +517,7 @@ st_glsl_to_nir_post_opts(struct st_context *st, struct gl_program *prog,
    if (nir->options->lower_int64_options)
       NIR_PASS(lowered_64bit_ops, nir, nir_lower_int64);
 
-   if (revectorize)
+   if (revectorize && !nir->options->vectorize_vec2_16bit)
       NIR_PASS_V(nir, nir_opt_vectorize, nullptr, nullptr);
 
    if (revectorize || lowered_64bit_ops)
diff --git a/src/panfrost/bifrost/bifrost_compile.c b/src/panfrost/bifrost/bifrost_compile.c
index b39e4b1eb5b..337e3e56454 100644
--- a/src/panfrost/bifrost/bifrost_compile.c
+++ b/src/panfrost/bifrost/bifrost_compile.c
@@ -4276,12 +4276,12 @@ bi_lower_bit_size(const nir_instr *instr, UNUSED void *data)
  * (8-bit in Bifrost, 32-bit in NIR TODO - workaround!). Some conversions need
  * to be scalarized due to type size.
  */
-static bool
-bi_vectorize_filter(const nir_instr *instr, void *data)
+static uint8_t
+bi_vectorize_filter(const nir_instr *instr, const void *data)
 {
         /* Defaults work for everything else */
         if (instr->type != nir_instr_type_alu)
-                return true;
+                return 0;
 
         const nir_alu_instr *alu = nir_instr_as_alu(instr);
 
@@ -4293,10 +4293,17 @@ bi_vectorize_filter(const nir_instr *instr, void *data)
         case nir_op_ushr:
         case nir_op_f2i16:
         case nir_op_f2u16:
-                return false;
+                return 1;
         default:
-                return true;
+                break;
         }
+
+        /* Vectorized instructions cannot write more than 32 bits */
+        int dst_bit_size = nir_dest_bit_size(alu->dest.dest);
+        if (dst_bit_size == 16)
+                return 2;
+        else
+                return 1;
 }
 
 static bool
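One subtlety shared by bi_vectorize_filter above and midgard_vectorize_filter below (commentary): both now return 0 for non-ALU instructions where the old filters returned true. The difference is harmless in practice, since instr_can_rewrite() only ever considers nir_instr_type_alu, so the reported width is never consulted for other instruction types.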
diff --git a/src/panfrost/midgard/midgard_compile.c b/src/panfrost/midgard/midgard_compile.c
index 6cd48caf0e5..3c17819750c 100644
--- a/src/panfrost/midgard/midgard_compile.c
+++ b/src/panfrost/midgard/midgard_compile.c
@@ -303,25 +303,20 @@ mdg_should_scalarize(const nir_instr *instr, const void *_unused)
 }
 
 /* Only vectorize int64 up to vec2 */
-static bool
-midgard_vectorize_filter(const nir_instr *instr, void *data)
+static uint8_t
+midgard_vectorize_filter(const nir_instr *instr, const void *data)
 {
         if (instr->type != nir_instr_type_alu)
-                return true;
+                return 0;
 
         const nir_alu_instr *alu = nir_instr_as_alu(instr);
-
-        unsigned num_components = alu->dest.dest.ssa.num_components;
-
         int src_bit_size = nir_src_bit_size(alu->src[0].src);
         int dst_bit_size = nir_dest_bit_size(alu->dest.dest);
 
-        if (src_bit_size == 64 || dst_bit_size == 64) {
-                if (num_components > 1)
-                        return false;
-        }
+        if (src_bit_size == 64 || dst_bit_size == 64)
+                return 2;
 
-        return true;
+        return 4;
 }
 
 static void
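The midgard change is behavior-preserving (commentary): the old filter emulated a vec2 cap for 64-bit ALU indirectly, by letting two scalars combine and then refusing to grow the resulting vector any further, while the new callback states the cap in one step.

/* Old flow: scalar i64 + scalar i64 -> filter passes -> combined to vec2;
 * the vec2 then failed the num_components > 1 check, so it could not grow
 * again. New flow: width 2 expresses that limit directly, and all other
 * bit sizes keep the default vec4 cap via the returned width of 4. */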