nir/opt_vectorize: add callback for max vectorization width
The callback allows to request different vectorization factors per instruction depending on e.g. bitsize or opcode. This patch also removes using the vectorize_vec2_16bit option from nir_opt_vectorize(). Reviewed-by: Alyssa Rosenzweig <alyssa@collabora.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13080>
This commit is contained in:

committed by
Marge Bot

parent
7ae206d76e
commit
bd151a256e
@@ -4042,14 +4042,16 @@ lower_bit_size_callback(const nir_instr *instr, void *_)
|
||||
return 0;
|
||||
}
|
||||
|
||||
static bool
|
||||
opt_vectorize_callback(const nir_instr *instr, void *_)
|
||||
static uint8_t
|
||||
opt_vectorize_callback(const nir_instr *instr, const void *_)
|
||||
{
|
||||
assert(instr->type == nir_instr_type_alu);
|
||||
nir_alu_instr *alu = nir_instr_as_alu(instr);
|
||||
unsigned bit_size = alu->dest.dest.ssa.bit_size;
|
||||
if (instr->type != nir_instr_type_alu)
|
||||
return 0;
|
||||
|
||||
const nir_alu_instr *alu = nir_instr_as_alu(instr);
|
||||
const unsigned bit_size = alu->dest.dest.ssa.bit_size;
|
||||
if (bit_size != 16)
|
||||
return false;
|
||||
return 1;
|
||||
|
||||
switch (alu->op) {
|
||||
case nir_op_fadd:
|
||||
@@ -4069,12 +4071,12 @@ opt_vectorize_callback(const nir_instr *instr, void *_)
|
||||
case nir_op_imax:
|
||||
case nir_op_umin:
|
||||
case nir_op_umax:
|
||||
return true;
|
||||
return 2;
|
||||
case nir_op_ishl: /* TODO: in NIR, these have 32bit shift operands */
|
||||
case nir_op_ishr: /* while Radeon needs 16bit operands when vectorized */
|
||||
case nir_op_ushr:
|
||||
default:
|
||||
return false;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -3228,6 +3228,15 @@ typedef enum {
|
||||
*/
|
||||
typedef bool (*nir_instr_filter_cb)(const nir_instr *, const void *);
|
||||
|
||||
/** A vectorization width callback
|
||||
*
|
||||
* Returns the maximum vectorization width per instruction.
|
||||
* 0, if the instruction must not be modified.
|
||||
*
|
||||
* The vectorization width must be a power of 2.
|
||||
*/
|
||||
typedef uint8_t (*nir_vectorize_cb)(const nir_instr *, const void *);
|
||||
|
||||
typedef struct nir_shader_compiler_options {
|
||||
bool lower_fdiv;
|
||||
bool lower_ffma16;
|
||||
@@ -3455,7 +3464,11 @@ typedef struct nir_shader_compiler_options {
|
||||
nir_instr_filter_cb lower_to_scalar_filter;
|
||||
|
||||
/**
|
||||
* Whether nir_opt_vectorize should only create 16-bit 2D vectors.
|
||||
* Disables potentially harmful algebraic transformations for architectures
|
||||
* with SIMD-within-a-register semantics.
|
||||
*
|
||||
* Note, to actually vectorize 16bit instructions, use nir_opt_vectorize()
|
||||
* with a suitable callback function.
|
||||
*/
|
||||
bool vectorize_vec2_16bit;
|
||||
|
||||
@@ -5485,9 +5498,7 @@ bool nir_lower_undef_to_zero(nir_shader *shader);
|
||||
|
||||
bool nir_opt_uniform_atomics(nir_shader *shader);
|
||||
|
||||
typedef bool (*nir_opt_vectorize_cb)(const nir_instr *instr, void *data);
|
||||
|
||||
bool nir_opt_vectorize(nir_shader *shader, nir_opt_vectorize_cb filter,
|
||||
bool nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
|
||||
void *data);
|
||||
|
||||
bool nir_opt_conditional_discard(nir_shader *shader);
|
||||
|
@@ -22,6 +22,16 @@
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* nir_opt_vectorize() aims to vectorize ALU instructions.
|
||||
*
|
||||
* The default vectorization width is 4.
|
||||
* If desired, a callback function which returns the max vectorization width
|
||||
* per instruction can be provided.
|
||||
*
|
||||
* The max vectorization width must be a power of 2.
|
||||
*/
|
||||
|
||||
#include "nir.h"
|
||||
#include "nir_vla.h"
|
||||
#include "nir_builder.h"
|
||||
@@ -125,7 +135,7 @@ instrs_equal(const void *data1, const void *data2)
|
||||
}
|
||||
|
||||
static bool
|
||||
instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
|
||||
instr_can_rewrite(nir_instr *instr)
|
||||
{
|
||||
switch (instr->type) {
|
||||
case nir_instr_type_alu: {
|
||||
@@ -139,12 +149,7 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
|
||||
return false;
|
||||
|
||||
/* no need to hash instructions which are already vectorized */
|
||||
if (alu->dest.dest.ssa.num_components >= 4)
|
||||
return false;
|
||||
|
||||
if (vectorize_16bit &&
|
||||
(alu->dest.dest.ssa.num_components >= 2 ||
|
||||
alu->dest.dest.ssa.bit_size != 16))
|
||||
if (alu->dest.dest.ssa.num_components >= instr->pass_flags)
|
||||
return false;
|
||||
|
||||
if (nir_op_infos[alu->op].output_size != 0)
|
||||
@@ -156,8 +161,8 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
|
||||
|
||||
/* don't hash instructions which are already swizzled
|
||||
* outside of max_components: these should better be scalarized */
|
||||
uint32_t mask = vectorize_16bit ? ~1 : ~3;
|
||||
for (unsigned j = 0; j < alu->dest.dest.ssa.num_components; j++) {
|
||||
uint32_t mask = ~(instr->pass_flags - 1);
|
||||
for (unsigned j = 1; j < alu->dest.dest.ssa.num_components; j++) {
|
||||
if ((alu->src[i].swizzle[0] & mask) != (alu->src[i].swizzle[j] & mask))
|
||||
return false;
|
||||
}
|
||||
@@ -179,10 +184,8 @@ instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
|
||||
* the same instructions into one vectorized instruction. Note that instr1
|
||||
* should dominate instr2.
|
||||
*/
|
||||
|
||||
static nir_instr *
|
||||
instr_try_combine(struct nir_shader *nir, struct set *instr_set,
|
||||
nir_instr *instr1, nir_instr *instr2)
|
||||
instr_try_combine(struct set *instr_set, nir_instr *instr1, nir_instr *instr2)
|
||||
{
|
||||
assert(instr1->type == nir_instr_type_alu);
|
||||
assert(instr2->type == nir_instr_type_alu);
|
||||
@@ -194,14 +197,10 @@ instr_try_combine(struct nir_shader *nir, struct set *instr_set,
|
||||
unsigned alu2_components = alu2->dest.dest.ssa.num_components;
|
||||
unsigned total_components = alu1_components + alu2_components;
|
||||
|
||||
if (total_components > 4)
|
||||
assert(instr1->pass_flags == instr2->pass_flags);
|
||||
if (total_components > instr1->pass_flags)
|
||||
return NULL;
|
||||
|
||||
if (nir->options->vectorize_vec2_16bit) {
|
||||
assert(total_components == 2);
|
||||
assert(alu1->dest.dest.ssa.bit_size == 16);
|
||||
}
|
||||
|
||||
nir_builder b;
|
||||
nir_builder_init(&b, nir_cf_node_get_function(&instr1->block->cf_node));
|
||||
b.cursor = nir_after_instr(instr1);
|
||||
@@ -352,28 +351,23 @@ vec_instr_set_destroy(struct set *instr_set)
|
||||
}
|
||||
|
||||
static bool
|
||||
vec_instr_set_add_or_rewrite(struct nir_shader *nir, struct set *instr_set,
|
||||
nir_instr *instr,
|
||||
nir_opt_vectorize_cb filter, void *data)
|
||||
vec_instr_set_add_or_rewrite(struct set *instr_set, nir_instr *instr,
|
||||
nir_vectorize_cb filter, void *data)
|
||||
{
|
||||
if (!instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit))
|
||||
return false;
|
||||
|
||||
if (filter && !filter(instr, data))
|
||||
return false;
|
||||
|
||||
/* set max vector to instr pass flags: this is used to hash swizzles */
|
||||
instr->pass_flags = nir->options->vectorize_vec2_16bit ? 2 : 4;
|
||||
instr->pass_flags = filter ? filter(instr, data) : 4;
|
||||
assert(util_is_power_of_two_or_zero(instr->pass_flags));
|
||||
|
||||
if (!instr_can_rewrite(instr))
|
||||
return false;
|
||||
|
||||
struct set_entry *entry = _mesa_set_search(instr_set, instr);
|
||||
if (entry) {
|
||||
nir_instr *old_instr = (nir_instr *) entry->key;
|
||||
_mesa_set_remove(instr_set, entry);
|
||||
nir_instr *new_instr = instr_try_combine(nir, instr_set,
|
||||
old_instr, instr);
|
||||
nir_instr *new_instr = instr_try_combine(instr_set, old_instr, instr);
|
||||
if (new_instr) {
|
||||
if (instr_can_rewrite(new_instr, nir->options->vectorize_vec2_16bit) &&
|
||||
(!filter || filter(new_instr, data)))
|
||||
if (instr_can_rewrite(new_instr))
|
||||
_mesa_set_add(instr_set, new_instr);
|
||||
return true;
|
||||
}
|
||||
@@ -384,25 +378,23 @@ vec_instr_set_add_or_rewrite(struct nir_shader *nir, struct set *instr_set,
|
||||
}
|
||||
|
||||
static bool
|
||||
vectorize_block(struct nir_shader *nir, nir_block *block,
|
||||
struct set *instr_set,
|
||||
nir_opt_vectorize_cb filter, void *data)
|
||||
vectorize_block(nir_block *block, struct set *instr_set,
|
||||
nir_vectorize_cb filter, void *data)
|
||||
{
|
||||
bool progress = false;
|
||||
|
||||
nir_foreach_instr_safe(instr, block) {
|
||||
if (vec_instr_set_add_or_rewrite(nir, instr_set, instr, filter, data))
|
||||
if (vec_instr_set_add_or_rewrite(instr_set, instr, filter, data))
|
||||
progress = true;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < block->num_dom_children; i++) {
|
||||
nir_block *child = block->dom_children[i];
|
||||
progress |= vectorize_block(nir, child, instr_set, filter, data);
|
||||
progress |= vectorize_block(child, instr_set, filter, data);
|
||||
}
|
||||
|
||||
nir_foreach_instr_reverse(instr, block) {
|
||||
if (instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit) &&
|
||||
(!filter || filter(instr, data)))
|
||||
if (instr_can_rewrite(instr))
|
||||
_mesa_set_remove_key(instr_set, instr);
|
||||
}
|
||||
|
||||
@@ -410,14 +402,14 @@ vectorize_block(struct nir_shader *nir, nir_block *block,
|
||||
}
|
||||
|
||||
static bool
|
||||
nir_opt_vectorize_impl(struct nir_shader *nir, nir_function_impl *impl,
|
||||
nir_opt_vectorize_cb filter, void *data)
|
||||
nir_opt_vectorize_impl(nir_function_impl *impl,
|
||||
nir_vectorize_cb filter, void *data)
|
||||
{
|
||||
struct set *instr_set = vec_instr_set_create();
|
||||
|
||||
nir_metadata_require(impl, nir_metadata_dominance);
|
||||
|
||||
bool progress = vectorize_block(nir, nir_start_block(impl), instr_set,
|
||||
bool progress = vectorize_block(nir_start_block(impl), instr_set,
|
||||
filter, data);
|
||||
|
||||
if (progress) {
|
||||
@@ -432,14 +424,14 @@ nir_opt_vectorize_impl(struct nir_shader *nir, nir_function_impl *impl,
|
||||
}
|
||||
|
||||
bool
|
||||
nir_opt_vectorize(nir_shader *shader, nir_opt_vectorize_cb filter,
|
||||
nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
|
||||
void *data)
|
||||
{
|
||||
bool progress = false;
|
||||
|
||||
nir_foreach_function(function, shader) {
|
||||
if (function->impl)
|
||||
progress |= nir_opt_vectorize_impl(shader, function->impl, filter, data);
|
||||
progress |= nir_opt_vectorize_impl(function->impl, filter, data);
|
||||
}
|
||||
|
||||
return progress;
|
||||
|
@@ -3067,11 +3067,11 @@ type_size(const struct glsl_type *type, bool bindless)
|
||||
/* Allow vectorizing of ALU instructions, but avoid vectorizing past what we
|
||||
* can handle for 64-bit values in TGSI.
|
||||
*/
|
||||
static bool
|
||||
ntt_should_vectorize_instr(const nir_instr *instr, void *data)
|
||||
static uint8_t
|
||||
ntt_should_vectorize_instr(const nir_instr *instr, const void *data)
|
||||
{
|
||||
if (instr->type != nir_instr_type_alu)
|
||||
return false;
|
||||
return 0;
|
||||
|
||||
nir_alu_instr *alu = nir_instr_as_alu(instr);
|
||||
|
||||
@@ -3085,7 +3085,7 @@ ntt_should_vectorize_instr(const nir_instr *instr, void *data)
|
||||
*
|
||||
* https://gitlab.freedesktop.org/virgl/virglrenderer/-/issues/195
|
||||
*/
|
||||
return false;
|
||||
return 1;
|
||||
|
||||
default:
|
||||
break;
|
||||
@@ -3102,10 +3102,10 @@ ntt_should_vectorize_instr(const nir_instr *instr, void *data)
|
||||
* 64-bit instrs in the first place, I don't see much reason to care about
|
||||
* this.
|
||||
*/
|
||||
return false;
|
||||
return 1;
|
||||
}
|
||||
|
||||
return true;
|
||||
return 4;
|
||||
}
|
||||
|
||||
static bool
|
||||
|
@@ -43,6 +43,18 @@ static bool si_alu_to_scalar_filter(const nir_instr *instr, const void *data)
|
||||
return true;
|
||||
}
|
||||
|
||||
static uint8_t si_vectorize_callback(const nir_instr *instr, const void *data)
|
||||
{
|
||||
if (instr->type != nir_instr_type_alu)
|
||||
return 0;
|
||||
|
||||
nir_alu_instr *alu = nir_instr_as_alu(instr);
|
||||
if (nir_dest_bit_size(alu->dest.dest) == 16)
|
||||
return 2;
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
void si_nir_opts(struct si_screen *sscreen, struct nir_shader *nir, bool first)
|
||||
{
|
||||
bool progress;
|
||||
@@ -114,7 +126,7 @@ void si_nir_opts(struct si_screen *sscreen, struct nir_shader *nir, bool first)
|
||||
NIR_PASS_V(nir, nir_opt_move_discards_to_top);
|
||||
|
||||
if (sscreen->options.fp16)
|
||||
NIR_PASS(progress, nir, nir_opt_vectorize, NULL, NULL);
|
||||
NIR_PASS(progress, nir, nir_opt_vectorize, si_vectorize_callback, NULL);
|
||||
} while (progress);
|
||||
|
||||
NIR_PASS_V(nir, nir_lower_var_copies);
|
||||
|
@@ -517,7 +517,7 @@ st_glsl_to_nir_post_opts(struct st_context *st, struct gl_program *prog,
|
||||
if (nir->options->lower_int64_options)
|
||||
NIR_PASS(lowered_64bit_ops, nir, nir_lower_int64);
|
||||
|
||||
if (revectorize)
|
||||
if (revectorize && !nir->options->vectorize_vec2_16bit)
|
||||
NIR_PASS_V(nir, nir_opt_vectorize, nullptr, nullptr);
|
||||
|
||||
if (revectorize || lowered_64bit_ops)
|
||||
|
@@ -4276,12 +4276,12 @@ bi_lower_bit_size(const nir_instr *instr, UNUSED void *data)
|
||||
* (8-bit in Bifrost, 32-bit in NIR TODO - workaround!). Some conversions need
|
||||
* to be scalarized due to type size. */
|
||||
|
||||
static bool
|
||||
bi_vectorize_filter(const nir_instr *instr, void *data)
|
||||
static uint8_t
|
||||
bi_vectorize_filter(const nir_instr *instr, const void *data)
|
||||
{
|
||||
/* Defaults work for everything else */
|
||||
if (instr->type != nir_instr_type_alu)
|
||||
return true;
|
||||
return 0;
|
||||
|
||||
const nir_alu_instr *alu = nir_instr_as_alu(instr);
|
||||
|
||||
@@ -4293,10 +4293,17 @@ bi_vectorize_filter(const nir_instr *instr, void *data)
|
||||
case nir_op_ushr:
|
||||
case nir_op_f2i16:
|
||||
case nir_op_f2u16:
|
||||
return false;
|
||||
return 1;
|
||||
default:
|
||||
return true;
|
||||
break;
|
||||
}
|
||||
|
||||
/* Vectorized instructions cannot write more than 32-bit */
|
||||
int dst_bit_size = nir_dest_bit_size(alu->dest.dest);
|
||||
if (dst_bit_size == 16)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}
|
||||
|
||||
static bool
|
||||
|
@@ -303,25 +303,20 @@ mdg_should_scalarize(const nir_instr *instr, const void *_unused)
|
||||
}
|
||||
|
||||
/* Only vectorize int64 up to vec2 */
|
||||
static bool
|
||||
midgard_vectorize_filter(const nir_instr *instr, void *data)
|
||||
static uint8_t
|
||||
midgard_vectorize_filter(const nir_instr *instr, const void *data)
|
||||
{
|
||||
if (instr->type != nir_instr_type_alu)
|
||||
return true;
|
||||
return 0;
|
||||
|
||||
const nir_alu_instr *alu = nir_instr_as_alu(instr);
|
||||
|
||||
unsigned num_components = alu->dest.dest.ssa.num_components;
|
||||
|
||||
int src_bit_size = nir_src_bit_size(alu->src[0].src);
|
||||
int dst_bit_size = nir_dest_bit_size(alu->dest.dest);
|
||||
|
||||
if (src_bit_size == 64 || dst_bit_size == 64) {
|
||||
if (num_components > 1)
|
||||
return false;
|
||||
}
|
||||
if (src_bit_size == 64 || dst_bit_size == 64)
|
||||
return 2;
|
||||
|
||||
return true;
|
||||
return 4;
|
||||
}
|
||||
|
||||
static void
|
||||
|
Reference in New Issue
Block a user