nir: allow specifying filter callback in lower_alu_to_scalar

A set of opcodes doesn't have enough flexibility in certain cases. E.g.
Utgard PP has a vector conditional select operation, but the condition is
always scalar. Lowering all vector selects to scalar increases the
instruction count, so we need a way to filter only those ops that can't
be handled in hardware.

Reviewed-by: Qiang Yu <yuq825@gmail.com>
Reviewed-by: Eric Anholt <eric@anholt.net>
Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Signed-off-by: Vasily Khoruzhick <anarsoul@gmail.com>
This commit is contained in:
Vasily Khoruzhick
2019-08-29 21:14:54 -07:00
parent f9f7cbc1aa
commit 9367d2ca37
16 changed files with 109 additions and 63 deletions

View File

@@ -200,7 +200,7 @@ radv_optimize_nir(struct nir_shader *shader, bool optimize_conservatively,
NIR_PASS(progress, shader, nir_remove_dead_variables, NIR_PASS(progress, shader, nir_remove_dead_variables,
nir_var_function_temp); nir_var_function_temp);
NIR_PASS_V(shader, nir_lower_alu_to_scalar, NULL); NIR_PASS_V(shader, nir_lower_alu_to_scalar, NULL, NULL);
NIR_PASS_V(shader, nir_lower_phis_to_scalar); NIR_PASS_V(shader, nir_lower_phis_to_scalar);
NIR_PASS(progress, shader, nir_copy_prop); NIR_PASS(progress, shader, nir_copy_prop);

View File

@@ -1382,7 +1382,7 @@ v3d_optimize_nir(struct nir_shader *s)
progress = false; progress = false;
NIR_PASS_V(s, nir_lower_vars_to_ssa); NIR_PASS_V(s, nir_lower_vars_to_ssa);
NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL); NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL, NULL);
NIR_PASS(progress, s, nir_lower_phis_to_scalar); NIR_PASS(progress, s, nir_lower_phis_to_scalar);
NIR_PASS(progress, s, nir_copy_prop); NIR_PASS(progress, s, nir_copy_prop);
NIR_PASS(progress, s, nir_opt_remove_phis); NIR_PASS(progress, s, nir_opt_remove_phis);

View File

@@ -3606,7 +3606,7 @@ bool nir_lower_alu(nir_shader *shader);
bool nir_lower_flrp(nir_shader *shader, unsigned lowering_mask, bool nir_lower_flrp(nir_shader *shader, unsigned lowering_mask,
bool always_precise, bool have_ffma); bool always_precise, bool have_ffma);
bool nir_lower_alu_to_scalar(nir_shader *shader, BITSET_WORD *lower_set); bool nir_lower_alu_to_scalar(nir_shader *shader, nir_instr_filter_cb cb, const void *data);
bool nir_lower_bool_to_float(nir_shader *shader); bool nir_lower_bool_to_float(nir_shader *shader);
bool nir_lower_bool_to_int32(nir_shader *shader); bool nir_lower_bool_to_int32(nir_shader *shader);
bool nir_lower_int_to_float(nir_shader *shader); bool nir_lower_int_to_float(nir_shader *shader);

View File

@@ -24,6 +24,11 @@
#include "nir.h" #include "nir.h"
#include "nir_builder.h" #include "nir_builder.h"
struct alu_to_scalar_data {
nir_instr_filter_cb cb;
const void *data;
};
/** @file nir_lower_alu_to_scalar.c /** @file nir_lower_alu_to_scalar.c
* *
* Replaces nir_alu_instr operations with more than one channel used in the * Replaces nir_alu_instr operations with more than one channel used in the
@@ -89,9 +94,9 @@ lower_reduction(nir_alu_instr *alu, nir_op chan_op, nir_op merge_op,
} }
static nir_ssa_def * static nir_ssa_def *
lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_state) lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
{ {
BITSET_WORD *lower_set = _state; struct alu_to_scalar_data *data = _data;
nir_alu_instr *alu = nir_instr_as_alu(instr); nir_alu_instr *alu = nir_instr_as_alu(instr);
unsigned num_src = nir_op_infos[alu->op].num_inputs; unsigned num_src = nir_op_infos[alu->op].num_inputs;
unsigned i, chan; unsigned i, chan;
@@ -102,7 +107,7 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_state)
b->cursor = nir_before_instr(&alu->instr); b->cursor = nir_before_instr(&alu->instr);
b->exact = alu->exact; b->exact = alu->exact;
if (lower_set && !BITSET_TEST(lower_set, alu->op)) if (data->cb && !data->cb(instr, data->data))
return NULL; return NULL;
#define LOWER_REDUCTION(name, chan, merge) \ #define LOWER_REDUCTION(name, chan, merge) \
@@ -246,10 +251,15 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_state)
} }
bool bool
nir_lower_alu_to_scalar(nir_shader *shader, BITSET_WORD *lower_set) nir_lower_alu_to_scalar(nir_shader *shader, nir_instr_filter_cb cb, const void *_data)
{ {
struct alu_to_scalar_data data = {
.cb = cb,
.data = _data,
};
return nir_shader_lower_instructions(shader, return nir_shader_lower_instructions(shader,
inst_is_vector_alu, inst_is_vector_alu,
lower_alu_instr_scalar, lower_alu_instr_scalar,
lower_set); &data);
} }

View File

@@ -124,7 +124,7 @@ ir3_optimize_loop(nir_shader *s)
OPT_V(s, nir_lower_vars_to_ssa); OPT_V(s, nir_lower_vars_to_ssa);
progress |= OPT(s, nir_opt_copy_prop_vars); progress |= OPT(s, nir_opt_copy_prop_vars);
progress |= OPT(s, nir_opt_dead_write_vars); progress |= OPT(s, nir_opt_dead_write_vars);
progress |= OPT(s, nir_lower_alu_to_scalar, NULL); progress |= OPT(s, nir_lower_alu_to_scalar, NULL, NULL);
progress |= OPT(s, nir_lower_phis_to_scalar); progress |= OPT(s, nir_lower_phis_to_scalar);
progress |= OPT(s, nir_copy_prop); progress |= OPT(s, nir_copy_prop);

View File

@@ -2559,7 +2559,7 @@ ttn_optimize_nir(nir_shader *nir, bool scalar)
NIR_PASS_V(nir, nir_lower_vars_to_ssa); NIR_PASS_V(nir, nir_lower_vars_to_ssa);
if (scalar) { if (scalar) {
NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL); NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
NIR_PASS_V(nir, nir_lower_phis_to_scalar); NIR_PASS_V(nir, nir_lower_phis_to_scalar);
} }

View File

@@ -208,25 +208,34 @@ etna_lower_io(nir_shader *shader, struct etna_shader_variant *v)
} }
} }
static void static bool
etna_lower_alu_to_scalar(nir_shader *shader, const struct etna_specs *specs) etna_alu_to_scalar_filter_cb(const nir_instr *instr, const void *data)
{ {
BITSET_DECLARE(scalar_ops, nir_num_opcodes); const struct etna_specs *specs = data;
BITSET_ZERO(scalar_ops);
BITSET_SET(scalar_ops, nir_op_frsq); if (instr->type != nir_instr_type_alu)
BITSET_SET(scalar_ops, nir_op_frcp); return false;
BITSET_SET(scalar_ops, nir_op_flog2);
BITSET_SET(scalar_ops, nir_op_fexp2);
BITSET_SET(scalar_ops, nir_op_fsqrt);
BITSET_SET(scalar_ops, nir_op_fcos);
BITSET_SET(scalar_ops, nir_op_fsin);
BITSET_SET(scalar_ops, nir_op_fdiv);
if (!specs->has_halti2_instructions) nir_alu_instr *alu = nir_instr_as_alu(instr);
BITSET_SET(scalar_ops, nir_op_fdot2); switch (alu->op) {
case nir_op_frsq:
case nir_op_frcp:
case nir_op_flog2:
case nir_op_fexp2:
case nir_op_fsqrt:
case nir_op_fcos:
case nir_op_fsin:
case nir_op_fdiv:
return true;
case nir_op_fdot2:
if (!specs->has_halti2_instructions)
return true;
break;
default:
break;
}
nir_lower_alu_to_scalar(shader, scalar_ops); return false;
} }
static void static void
@@ -607,7 +616,7 @@ etna_compile_shader_nir(struct etna_shader_variant *v)
OPT_V(s, nir_lower_vars_to_ssa); OPT_V(s, nir_lower_vars_to_ssa);
OPT_V(s, nir_lower_indirect_derefs, nir_var_all); OPT_V(s, nir_lower_indirect_derefs, nir_var_all);
OPT_V(s, nir_lower_tex, &(struct nir_lower_tex_options) { .lower_txp = ~0u }); OPT_V(s, nir_lower_tex, &(struct nir_lower_tex_options) { .lower_txp = ~0u });
OPT_V(s, etna_lower_alu_to_scalar, specs); OPT_V(s, nir_lower_alu_to_scalar, etna_alu_to_scalar_filter_cb, specs);
etna_optimize_loop(s); etna_optimize_loop(s);
@@ -627,7 +636,7 @@ etna_compile_shader_nir(struct etna_shader_variant *v)
nir_print_shader(s, stdout); nir_print_shader(s, stdout);
while( OPT(s, nir_opt_vectorize) ); while( OPT(s, nir_opt_vectorize) );
OPT_V(s, etna_lower_alu_to_scalar, specs); OPT_V(s, nir_lower_alu_to_scalar, etna_alu_to_scalar_filter_cb, specs);
NIR_PASS_V(s, nir_remove_dead_variables, nir_var_function_temp); NIR_PASS_V(s, nir_remove_dead_variables, nir_var_function_temp);
NIR_PASS_V(s, nir_opt_algebraic_late); NIR_PASS_V(s, nir_opt_algebraic_late);

View File

@@ -1062,6 +1062,29 @@ static void cleanup_binning(struct ir2_context *ctx)
ir2_optimize_nir(ctx->nir, false); ir2_optimize_nir(ctx->nir, false);
} }
static bool
ir2_alu_to_scalar_filter_cb(const nir_instr *instr, const void *data)
{
if (instr->type != nir_instr_type_alu)
return false;
nir_alu_instr *alu = nir_instr_as_alu(instr);
switch (alu->op) {
case nir_op_frsq:
case nir_op_frcp:
case nir_op_flog2:
case nir_op_fexp2:
case nir_op_fsqrt:
case nir_op_fcos:
case nir_op_fsin:
return true;
default:
break;
}
return false;
}
void void
ir2_nir_compile(struct ir2_context *ctx, bool binning) ir2_nir_compile(struct ir2_context *ctx, bool binning)
{ {
@@ -1085,17 +1108,7 @@ ir2_nir_compile(struct ir2_context *ctx, bool binning)
OPT_V(ctx->nir, nir_lower_bool_to_float); OPT_V(ctx->nir, nir_lower_bool_to_float);
OPT_V(ctx->nir, nir_lower_to_source_mods, nir_lower_all_source_mods); OPT_V(ctx->nir, nir_lower_to_source_mods, nir_lower_all_source_mods);
/* TODO: static bitset ? */ OPT_V(ctx->nir, nir_lower_alu_to_scalar, ir2_alu_to_scalar_filter_cb, NULL);
BITSET_DECLARE(scalar_ops, nir_num_opcodes);
BITSET_ZERO(scalar_ops);
BITSET_SET(scalar_ops, nir_op_frsq);
BITSET_SET(scalar_ops, nir_op_frcp);
BITSET_SET(scalar_ops, nir_op_flog2);
BITSET_SET(scalar_ops, nir_op_fexp2);
BITSET_SET(scalar_ops, nir_op_fsqrt);
BITSET_SET(scalar_ops, nir_op_fcos);
BITSET_SET(scalar_ops, nir_op_fsin);
OPT_V(ctx->nir, nir_lower_alu_to_scalar, scalar_ops);
OPT_V(ctx->nir, nir_lower_locals_to_regs); OPT_V(ctx->nir, nir_lower_locals_to_regs);

View File

@@ -110,7 +110,7 @@ lima_program_optimize_vs_nir(struct nir_shader *s)
progress = false; progress = false;
NIR_PASS_V(s, nir_lower_vars_to_ssa); NIR_PASS_V(s, nir_lower_vars_to_ssa);
NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL); NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL, NULL);
NIR_PASS(progress, s, nir_lower_phis_to_scalar); NIR_PASS(progress, s, nir_lower_phis_to_scalar);
NIR_PASS(progress, s, nir_copy_prop); NIR_PASS(progress, s, nir_copy_prop);
NIR_PASS(progress, s, nir_opt_remove_phis); NIR_PASS(progress, s, nir_opt_remove_phis);
@@ -145,24 +145,38 @@ lima_program_optimize_vs_nir(struct nir_shader *s)
nir_sweep(s); nir_sweep(s);
} }
void static bool
lima_program_optimize_fs_nir(struct nir_shader *s) lima_alu_to_scalar_filter_cb(const nir_instr *instr, const void *data)
{ {
BITSET_DECLARE(alu_lower, nir_num_opcodes) = {0}; if (instr->type != nir_instr_type_alu)
bool progress; return false;
BITSET_SET(alu_lower, nir_op_frcp); nir_alu_instr *alu = nir_instr_as_alu(instr);
BITSET_SET(alu_lower, nir_op_frsq); switch (alu->op) {
BITSET_SET(alu_lower, nir_op_flog2); case nir_op_frcp:
BITSET_SET(alu_lower, nir_op_fexp2); case nir_op_frsq:
BITSET_SET(alu_lower, nir_op_fsqrt); case nir_op_flog2:
BITSET_SET(alu_lower, nir_op_fsin); case nir_op_fexp2:
BITSET_SET(alu_lower, nir_op_fcos); case nir_op_fsqrt:
case nir_op_fsin:
case nir_op_fcos:
/* nir vec4 fcsel assumes that each component of the condition will be /* nir vec4 fcsel assumes that each component of the condition will be
* used to select the same component from the two options, but lima * used to select the same component from the two options, but lima
* can't implement that since we only have 1 component condition */ * can't implement that since we only have 1 component condition */
BITSET_SET(alu_lower, nir_op_fcsel); case nir_op_fcsel:
BITSET_SET(alu_lower, nir_op_bcsel); case nir_op_bcsel:
return true;
default:
break;
}
return false;
}
void
lima_program_optimize_fs_nir(struct nir_shader *s)
{
bool progress;
NIR_PASS_V(s, nir_lower_fragcoord_wtrans); NIR_PASS_V(s, nir_lower_fragcoord_wtrans);
NIR_PASS_V(s, nir_lower_io, nir_var_all, type_size, 0); NIR_PASS_V(s, nir_lower_io, nir_var_all, type_size, 0);
@@ -178,7 +192,7 @@ lima_program_optimize_fs_nir(struct nir_shader *s)
progress = false; progress = false;
NIR_PASS_V(s, nir_lower_vars_to_ssa); NIR_PASS_V(s, nir_lower_vars_to_ssa);
NIR_PASS(progress, s, nir_lower_alu_to_scalar, alu_lower); NIR_PASS(progress, s, nir_lower_alu_to_scalar, lima_alu_to_scalar_filter_cb, NULL);
NIR_PASS(progress, s, nir_lower_phis_to_scalar); NIR_PASS(progress, s, nir_lower_phis_to_scalar);
NIR_PASS(progress, s, nir_copy_prop); NIR_PASS(progress, s, nir_copy_prop);
NIR_PASS(progress, s, nir_opt_remove_phis); NIR_PASS(progress, s, nir_opt_remove_phis);

View File

@@ -3500,7 +3500,7 @@ Converter::run()
NIR_PASS_V(nir, nir_lower_regs_to_ssa); NIR_PASS_V(nir, nir_lower_regs_to_ssa);
NIR_PASS_V(nir, nir_lower_load_const_to_scalar); NIR_PASS_V(nir, nir_lower_load_const_to_scalar);
NIR_PASS_V(nir, nir_lower_vars_to_ssa); NIR_PASS_V(nir, nir_lower_vars_to_ssa);
NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL); NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
NIR_PASS_V(nir, nir_lower_phis_to_scalar); NIR_PASS_V(nir, nir_lower_phis_to_scalar);
do { do {

View File

@@ -817,7 +817,7 @@ si_nir_opts(struct nir_shader *nir)
NIR_PASS(progress, nir, nir_opt_copy_prop_vars); NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
NIR_PASS(progress, nir, nir_opt_dead_write_vars); NIR_PASS(progress, nir, nir_opt_dead_write_vars);
NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL); NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
NIR_PASS_V(nir, nir_lower_phis_to_scalar); NIR_PASS_V(nir, nir_lower_phis_to_scalar);
/* (Constant) copy propagation is needed for txf with offsets. */ /* (Constant) copy propagation is needed for txf with offsets. */

View File

@@ -1530,7 +1530,7 @@ vc4_optimize_nir(struct nir_shader *s)
progress = false; progress = false;
NIR_PASS_V(s, nir_lower_vars_to_ssa); NIR_PASS_V(s, nir_lower_vars_to_ssa);
NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL); NIR_PASS(progress, s, nir_lower_alu_to_scalar, NULL, NULL);
NIR_PASS(progress, s, nir_lower_phis_to_scalar); NIR_PASS(progress, s, nir_lower_phis_to_scalar);
NIR_PASS(progress, s, nir_copy_prop); NIR_PASS(progress, s, nir_copy_prop);
NIR_PASS(progress, s, nir_opt_remove_phis); NIR_PASS(progress, s, nir_opt_remove_phis);

View File

@@ -518,7 +518,7 @@ brw_nir_optimize(nir_shader *nir, const struct brw_compiler *compiler,
OPT(nir_opt_combine_stores, nir_var_all); OPT(nir_opt_combine_stores, nir_var_all);
if (is_scalar) { if (is_scalar) {
OPT(nir_lower_alu_to_scalar, NULL); OPT(nir_lower_alu_to_scalar, NULL, NULL);
} }
OPT(nir_copy_prop); OPT(nir_copy_prop);
@@ -654,7 +654,7 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir,
const bool is_scalar = compiler->scalar_stage[nir->info.stage]; const bool is_scalar = compiler->scalar_stage[nir->info.stage];
if (is_scalar) { if (is_scalar) {
OPT(nir_lower_alu_to_scalar, NULL); OPT(nir_lower_alu_to_scalar, NULL, NULL);
} }
if (nir->info.stage == MESA_SHADER_GEOMETRY) if (nir->info.stage == MESA_SHADER_GEOMETRY)
@@ -871,7 +871,7 @@ brw_postprocess_nir(nir_shader *nir, const struct brw_compiler *compiler,
OPT(brw_nir_lower_conversions); OPT(brw_nir_lower_conversions);
if (is_scalar) if (is_scalar)
OPT(nir_lower_alu_to_scalar, NULL); OPT(nir_lower_alu_to_scalar, NULL, NULL);
OPT(nir_lower_to_source_mods, nir_lower_all_source_mods); OPT(nir_lower_to_source_mods, nir_lower_all_source_mods);
OPT(nir_copy_prop); OPT(nir_copy_prop);
OPT(nir_opt_dce); OPT(nir_opt_dce);

View File

@@ -247,7 +247,7 @@ st_nir_opts(nir_shader *nir, bool scalar)
NIR_PASS(progress, nir, nir_opt_dead_write_vars); NIR_PASS(progress, nir, nir_opt_dead_write_vars);
if (scalar) { if (scalar) {
NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL); NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
NIR_PASS_V(nir, nir_lower_phis_to_scalar); NIR_PASS_V(nir, nir_lower_phis_to_scalar);
} }
@@ -363,7 +363,7 @@ st_glsl_to_nir(struct st_context *st, struct gl_program *prog,
NIR_PASS_V(nir, nir_lower_var_copies); NIR_PASS_V(nir, nir_lower_var_copies);
if (is_scalar) { if (is_scalar) {
NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL); NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
} }
/* before buffers and vars_to_ssa */ /* before buffers and vars_to_ssa */

View File

@@ -57,7 +57,7 @@ optimize_nir(nir_shader *nir)
NIR_PASS(progress, nir, nir_opt_constant_folding); NIR_PASS(progress, nir, nir_opt_constant_folding);
NIR_PASS(progress, nir, nir_lower_vars_to_ssa); NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
NIR_PASS(progress, nir, nir_lower_alu_to_scalar, NULL); NIR_PASS(progress, nir, nir_lower_alu_to_scalar, NULL, NULL);
NIR_PASS(progress, nir, nir_opt_if, true); NIR_PASS(progress, nir, nir_opt_if, true);
} while (progress); } while (progress);

View File

@@ -59,7 +59,7 @@ compile_shader(char **argv)
NIR_PASS_V(nir[i], nir_split_var_copies); NIR_PASS_V(nir[i], nir_split_var_copies);
NIR_PASS_V(nir[i], nir_lower_var_copies); NIR_PASS_V(nir[i], nir_lower_var_copies);
NIR_PASS_V(nir[i], nir_lower_alu_to_scalar, NULL); NIR_PASS_V(nir[i], nir_lower_alu_to_scalar, NULL, NULL);
/* before buffers and vars_to_ssa */ /* before buffers and vars_to_ssa */
NIR_PASS_V(nir[i], gl_nir_lower_bindless_images); NIR_PASS_V(nir[i], gl_nir_lower_bindless_images);