zink: fix/improve handling for multi-component bitfield ops

the original improvement for this correctly handled cases where
the offset/count values were swizzled with .xxxx, but it was broken
for any other swizzling

this adds a nir pass that splits each multi-component bitfield op into
one scalar op per component, using each source's swizzle to pick the
params for that component

fixes #6697

Fixes: 8e97f51c67 ("zink: handle swizzled offset/count values for shader bitfield ops")

Reviewed-by: Dave Airlie <airlied@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/18706>
This commit is contained in:
Mike Blumenkrantz
2022-09-20 11:42:41 -04:00
committed by Marge Bot
parent 111bf8bfee
commit af775f842c
2 changed files with 52 additions and 41 deletions

View File

@@ -1944,19 +1944,6 @@ needs_derivative_control(nir_alu_instr *alu)
}
}
static SpvId
unswizzle_src(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId src, unsigned num_components)
{
/* value may have already been cast to ivec, so cast back */
SpvId cast_type = get_uvec_type(ctx, ssa->bit_size, num_components);
src = emit_bitcast(ctx, cast_type, src);
/* extract from swizzled vec */
SpvId type = spirv_builder_type_uint(&ctx->builder, ssa->bit_size);
uint32_t idx = 0;
return spirv_builder_emit_composite_extract(&ctx->builder, type, src, &idx, 1);
}
static void
emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
{
@@ -1973,34 +1960,6 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
if (needs_derivative_control(alu))
spirv_builder_emit_cap(&ctx->builder, SpvCapabilityDerivativeControl);
/* modify params here */
switch (alu->op) {
/* Offset must be an integer type scalar.
* Offset is the lowest-order bit of the bit field.
* It is consumed as an unsigned value.
*
* Count must be an integer type scalar.
*
* if these ops have more than one component in the dest, then their offset and count
* are swizzled like ssa_1.xxx, but only a single scalar can be provided
*/
case nir_op_ubitfield_extract:
case nir_op_ibitfield_extract:
if (num_components > 1) {
src[1] = unswizzle_src(ctx, alu->src[1].src.ssa, src[1], num_components);
src[2] = unswizzle_src(ctx, alu->src[2].src.ssa, src[2], num_components);
}
break;
case nir_op_bitfield_insert:
if (num_components > 1) {
src[2] = unswizzle_src(ctx, alu->src[2].src.ssa, src[2], num_components);
src[3] = unswizzle_src(ctx, alu->src[3].src.ssa, src[3], num_components);
}
break;
default:
break;
}
SpvId result = 0;
switch (alu->op) {
case nir_op_mov:

View File

@@ -21,6 +21,7 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "nir_opcodes.h"
#include "zink_context.h"
#include "zink_compiler.h"
#include "zink_program.h"
@@ -3028,6 +3029,56 @@ match_tex_dests(nir_shader *shader)
return nir_shader_instructions_pass(shader, match_tex_dests_instr, nir_metadata_dominance, NULL);
}
static bool
split_bitfields_instr(nir_builder *b, nir_instr *in, void *data)
{
if (in->type != nir_instr_type_alu)
return false;
nir_alu_instr *alu = nir_instr_as_alu(in);
switch (alu->op) {
case nir_op_ubitfield_extract:
case nir_op_ibitfield_extract:
case nir_op_bitfield_insert:
break;
default:
return false;
}
unsigned num_components = nir_dest_num_components(alu->dest.dest);
if (num_components == 1)
return false;
b->cursor = nir_before_instr(in);
nir_ssa_def *dests[NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i < num_components; i++) {
if (alu->op == nir_op_bitfield_insert)
dests[i] = nir_bitfield_insert(b,
nir_channel(b, alu->src[0].src.ssa, alu->src[0].swizzle[i]),
nir_channel(b, alu->src[1].src.ssa, alu->src[1].swizzle[i]),
nir_channel(b, alu->src[2].src.ssa, alu->src[2].swizzle[i]),
nir_channel(b, alu->src[3].src.ssa, alu->src[3].swizzle[i]));
else if (alu->op == nir_op_ubitfield_extract)
dests[i] = nir_ubitfield_extract(b,
nir_channel(b, alu->src[0].src.ssa, alu->src[0].swizzle[i]),
nir_channel(b, alu->src[1].src.ssa, alu->src[1].swizzle[i]),
nir_channel(b, alu->src[2].src.ssa, alu->src[2].swizzle[i]));
else
dests[i] = nir_ibitfield_extract(b,
nir_channel(b, alu->src[0].src.ssa, alu->src[0].swizzle[i]),
nir_channel(b, alu->src[1].src.ssa, alu->src[1].swizzle[i]),
nir_channel(b, alu->src[2].src.ssa, alu->src[2].swizzle[i]));
}
nir_ssa_def *dest = nir_vec(b, dests, num_components);
nir_ssa_def_rewrite_uses_after(&alu->dest.dest.ssa, dest, in);
nir_instr_remove(in);
return true;
}
static bool
split_bitfields(nir_shader *shader)
{
return nir_shader_instructions_pass(shader, split_bitfields_instr, nir_metadata_dominance, NULL);
}
struct zink_shader *
zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
const struct pipe_stream_output_info *so_info)
@@ -3066,6 +3117,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
NIR_PASS_V(nir, nir_lower_regs_to_ssa);
NIR_PASS_V(nir, lower_baseinstance);
NIR_PASS_V(nir, lower_sparse);
NIR_PASS_V(nir, split_bitfields);
if (screen->info.have_EXT_shader_demote_to_helper_invocation) {
NIR_PASS_V(nir, nir_lower_discard_or_demote,