agx: Emit splits for intrinsics

This allows optimizing the extracts.

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16268>
Author: Alyssa Rosenzweig
Date: 2022-04-12 19:46:13 -04:00
Parent: d06394095b
Commit: 4f78141c77
2 changed files with 72 additions and 59 deletions
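The diffs below all follow one pattern: instead of each load helper emitting a p_combine directly into the NIR destination, every vector-producing intrinsic loads into a temporary vector, splits it into per-component scalars in a dests[] array, and agx_emit_intrinsic combines those scalars back into the destination at the end. Because the split and the final combine are copy-like pseudo-instructions, later copy-propagation and dead-code elimination can drop any channel that is never read; that is the optimization the commit message refers to. The definitions of agx_emit_split and agx_emit_cached_split fall outside the hunks shown here. As a rough sketch only (signature inferred from the call sites, not the commit's actual code), agx_emit_split plausibly amounts to:

/* Sketch only, inferred from the call sites below; not the commit's actual
 * implementation. Split a vector value into one scalar per channel so that
 * channels which are never read can later be dead-code eliminated.
 */
static void
agx_emit_split(agx_builder *b, agx_index *dests, agx_index vec, unsigned n)
{
   for (unsigned i = 0; i < n; ++i)
      dests[i] = agx_p_extract(b, vec, i);
}

agx_emit_cached_split is presumably the same operation with the resulting components cached, so repeated splits of the same value reuse the extracts instead of emitting new ones.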

View File

@@ -221,8 +221,8 @@ agx_udiv_const(agx_builder *b, agx_index P, uint32_t Q)
 }
 
 /* AGX appears to lack support for vertex attributes. Lower to global loads. */
-static agx_instr *
-agx_emit_load_attr(agx_builder *b, nir_intrinsic_instr *instr)
+static void
+agx_emit_load_attr(agx_builder *b, agx_index *dests, nir_intrinsic_instr *instr)
 {
    nir_src *offset_src = nir_get_io_offset_src(instr);
    assert(nir_src_is_const(*offset_src) && "no attribute indirects");
@@ -259,31 +259,24 @@ agx_emit_load_attr(agx_builder *b, nir_intrinsic_instr *instr)
    /* Load the data */
    assert(instr->num_components <= 4);
-   bool pad = ((attrib.nr_comps_minus_1 + 1) < instr->num_components);
-   agx_index real_dest = agx_dest_index(&instr->dest);
-   agx_index dest = pad ? agx_temp(b->shader, AGX_SIZE_32) : real_dest;
-   agx_device_load_to(b, dest, base, offset, attrib.format,
+   unsigned actual_comps = (attrib.nr_comps_minus_1 + 1);
+   agx_index vec = agx_vec_for_dest(b->shader, &instr->dest);
+   agx_device_load_to(b, vec, base, offset, attrib.format,
                       BITFIELD_MASK(attrib.nr_comps_minus_1 + 1), 0);
    agx_wait(b, 0);
+   agx_emit_split(b, dests, vec, actual_comps);
 
-   if (pad) {
-      agx_index one = agx_mov_imm(b, 32, fui(1.0));
-      agx_index zero = agx_mov_imm(b, 32, 0);
-      agx_index channels[4] = { zero, zero, zero, one };
-
-      for (unsigned i = 0; i < (attrib.nr_comps_minus_1 + 1); ++i)
-         channels[i] = agx_p_extract(b, dest, i);
-
-      for (unsigned i = instr->num_components; i < 4; ++i)
-         channels[i] = agx_null();
-
-      agx_p_combine_to(b, real_dest, channels[0], channels[1], channels[2], channels[3]);
-   }
-
-   return NULL;
+   agx_index one = agx_mov_imm(b, 32, fui(1.0));
+   agx_index zero = agx_mov_imm(b, 32, 0);
+   agx_index default_value[4] = { zero, zero, zero, one };
+
+   for (unsigned i = actual_comps; i < instr->num_components; ++i)
+      dests[i] = default_value[i];
 }
 
-static agx_instr *
-agx_emit_load_vary_flat(agx_builder *b, nir_intrinsic_instr *instr)
+static void
+agx_emit_load_vary_flat(agx_builder *b, agx_index *dests, nir_intrinsic_instr *instr)
 {
    unsigned components = instr->num_components;
    assert(components >= 1 && components <= 4);
@@ -293,20 +286,15 @@ agx_emit_load_vary_flat(agx_builder *b, nir_intrinsic_instr *instr)
    unsigned imm_index = b->shader->varyings[nir_intrinsic_base(instr)];
    imm_index += nir_src_as_uint(*offset);
 
-   agx_index chan[4] = { agx_null() };
-
    for (unsigned i = 0; i < components; ++i) {
       /* vec3 for each vertex, unknown what first 2 channels are for */
       agx_index values = agx_ld_vary_flat(b, agx_immediate(imm_index + i), 1);
-      chan[i] = agx_p_extract(b, values, 2);
+      dests[i] = agx_p_extract(b, values, 2);
    }
-
-   return agx_p_combine_to(b, agx_dest_index(&instr->dest),
-                           chan[0], chan[1], chan[2], chan[3]);
 }
 
-static agx_instr *
-agx_emit_load_vary(agx_builder *b, nir_intrinsic_instr *instr)
+static void
+agx_emit_load_vary(agx_builder *b, agx_index *dests, nir_intrinsic_instr *instr)
 {
    ASSERTED unsigned components = instr->num_components;
    ASSERTED nir_intrinsic_instr *parent = nir_src_as_intrinsic(instr->src[0]);
@@ -322,8 +310,9 @@ agx_emit_load_vary(agx_builder *b, nir_intrinsic_instr *instr)
    unsigned imm_index = b->shader->varyings[nir_intrinsic_base(instr)];
    imm_index += nir_src_as_uint(*offset) * 4;
 
-   return agx_ld_vary_to(b, agx_dest_index(&instr->dest),
-                         agx_immediate(imm_index), components, true);
+   agx_index vec = agx_vec_for_intr(b->shader, instr);
+   agx_ld_vary_to(b, vec, agx_immediate(imm_index), components, true);
+   agx_emit_split(b, dests, vec, components);
 }
 
 static agx_instr *
@@ -380,8 +369,8 @@ agx_emit_fragment_out(agx_builder *b, nir_intrinsic_instr *instr)
                              b->shader->key->fs.tib_formats[rt]);
 }
 
-static agx_instr *
-agx_emit_load_tile(agx_builder *b, nir_intrinsic_instr *instr)
+static void
+agx_emit_load_tile(agx_builder *b, agx_index *dests, nir_intrinsic_instr *instr)
 {
    const nir_variable *var =
       nir_find_variable_with_driver_location(b->shader->nir,
@@ -399,8 +388,9 @@ agx_emit_load_tile(agx_builder *b, nir_intrinsic_instr *instr)
    b->shader->did_writeout = true;
    b->shader->out->reads_tib = true;
 
-   return agx_ld_tile_to(b, agx_dest_index(&instr->dest),
-                         b->shader->key->fs.tib_formats[rt]);
+   agx_index vec = agx_vec_for_dest(b->shader, &instr->dest);
+   agx_ld_tile_to(b, vec, b->shader->key->fs.tib_formats[rt]);
+   agx_emit_split(b, dests, vec, 4);
 }
 
 static enum agx_format
@@ -415,7 +405,7 @@ agx_format_for_bits(unsigned bits)
 }
 
 static agx_instr *
-agx_emit_load_ubo(agx_builder *b, nir_intrinsic_instr *instr)
+agx_emit_load_ubo(agx_builder *b, agx_index dst, nir_intrinsic_instr *instr)
 {
    bool kernel_input = (instr->intrinsic == nir_intrinsic_load_kernel_input);
    nir_src *offset = nir_get_io_offset_src(instr);
@@ -439,31 +429,27 @@ agx_emit_load_ubo(agx_builder *b, nir_intrinsic_instr *instr)
    /* Load the data */
    assert(instr->num_components <= 4);
 
-   agx_device_load_to(b, agx_dest_index(&instr->dest),
-                      base, agx_src_index(offset),
+   agx_device_load_to(b, dst, base, agx_src_index(offset),
                       agx_format_for_bits(nir_dest_bit_size(instr->dest)),
                       BITFIELD_MASK(instr->num_components), 0);
+   agx_wait(b, 0);
+   agx_emit_cached_split(b, dst, instr->num_components);
 
-   return agx_wait(b, 0);
+   return NULL;
 }
 
-static agx_instr *
-agx_emit_load_frag_coord(agx_builder *b, nir_intrinsic_instr *instr)
+static void
+agx_emit_load_frag_coord(agx_builder *b, agx_index *dests, nir_intrinsic_instr *instr)
 {
-   agx_index xy[2];
-
+   /* xy */
    for (unsigned i = 0; i < 2; ++i) {
-      xy[i] = agx_fadd(b, agx_convert(b, agx_immediate(AGX_CONVERT_U32_TO_F),
+      dests[i] = agx_fadd(b, agx_convert(b, agx_immediate(AGX_CONVERT_U32_TO_F),
               agx_get_sr(b, 32, AGX_SR_THREAD_POSITION_IN_GRID_X + i),
               AGX_ROUND_RTE), agx_immediate_f(0.5f));
    }
 
-   /* Ordering by the ABI */
-   agx_index z = agx_ld_vary(b, agx_immediate(1), 1, false);
-   agx_index w = agx_ld_vary(b, agx_immediate(0), 1, false);
-
-   return agx_p_combine_to(b, agx_dest_index(&instr->dest),
-                           xy[0], xy[1], z, w);
+   dests[2] = agx_ld_vary(b, agx_immediate(1), 1, false); /* z */
+   dests[3] = agx_ld_vary(b, agx_immediate(0), 1, false); /* w */
 }
 
 static agx_instr *
@@ -500,6 +486,7 @@ agx_emit_intrinsic(agx_builder *b, nir_intrinsic_instr *instr)
    agx_index dst = nir_intrinsic_infos[instr->intrinsic].has_dest ?
       agx_dest_index(&instr->dest) : agx_null();
    gl_shader_stage stage = b->shader->stage;
+   agx_index dests[4] = { agx_null() };
 
    switch (instr->intrinsic) {
    case nir_intrinsic_load_barycentric_pixel:
@@ -511,16 +498,19 @@ agx_emit_intrinsic(agx_builder *b, nir_intrinsic_instr *instr)
       return NULL;
 
    case nir_intrinsic_load_interpolated_input:
       assert(stage == MESA_SHADER_FRAGMENT);
-      return agx_emit_load_vary(b, instr);
+      agx_emit_load_vary(b, dests, instr);
+      break;
 
    case nir_intrinsic_load_input:
       if (stage == MESA_SHADER_FRAGMENT)
-         return agx_emit_load_vary_flat(b, instr);
+         agx_emit_load_vary_flat(b, dests, instr);
       else if (stage == MESA_SHADER_VERTEX)
-         return agx_emit_load_attr(b, instr);
+         agx_emit_load_attr(b, dests, instr);
       else
          unreachable("Unsupported shader stage");
+      break;
 
    case nir_intrinsic_store_output:
       if (stage == MESA_SHADER_FRAGMENT)
          return agx_emit_fragment_out(b, instr);
@@ -531,14 +521,16 @@ agx_emit_intrinsic(agx_builder *b, nir_intrinsic_instr *instr)
 
    case nir_intrinsic_load_output:
       assert(stage == MESA_SHADER_FRAGMENT);
-      return agx_emit_load_tile(b, instr);
+      agx_emit_load_tile(b, dests, instr);
+      break;
 
    case nir_intrinsic_load_ubo:
    case nir_intrinsic_load_kernel_input:
-      return agx_emit_load_ubo(b, instr);
+      return agx_emit_load_ubo(b, dst, instr);
 
    case nir_intrinsic_load_frag_coord:
-      return agx_emit_load_frag_coord(b, instr);
+      agx_emit_load_frag_coord(b, dests, instr);
+      break;
 
    case nir_intrinsic_discard:
       return agx_emit_discard(b, instr);
@@ -561,6 +553,14 @@ agx_emit_intrinsic(agx_builder *b, nir_intrinsic_instr *instr)
      fprintf(stderr, "Unhandled intrinsic %s\n", nir_intrinsic_infos[instr->intrinsic].name);
      unreachable("Unhandled intrinsic");
    }
+
+   /* If we got here, there is a vector destination for the intrinsic composed
+    * of separate scalars. Its components are specified separately in the dests
+    * array. We need to combine them so the vector destination itself is valid.
+    * If only individual components are accessed, this combine will be dead code
+    * eliminated.
+    */
+   return agx_emit_combine_to(b, dst, dests[0], dests[1], dests[2], dests[3]);
 }
 
 static agx_index
@@ -831,7 +831,7 @@ agx_emit_alu(agx_builder *b, nir_alu_instr *instr)
    case nir_op_vec2:
    case nir_op_vec3:
    case nir_op_vec4:
-      return agx_p_combine_to(b, dst, s0, s1, s2, s3);
+      return agx_emit_combine_to(b, dst, s0, s1, s2, s3);
 
    case nir_op_vec8:
    case nir_op_vec16:
@@ -966,14 +966,15 @@ agx_emit_tex(agx_builder *b, nir_tex_instr *instr)
       }
    }
 
-   agx_texture_sample_to(b, agx_dest_index(&instr->dest),
-                         coords, lod, texture, sampler, offset,
+   agx_index dst = agx_dest_index(&instr->dest);
+   agx_texture_sample_to(b, dst, coords, lod, texture, sampler, offset,
                          agx_tex_dim(instr->sampler_dim, instr->is_array),
                          agx_lod_mode_for_nir(instr->op),
                          0xF, /* TODO: wrmask */
                          0);
 
    agx_wait(b, 0);
+   agx_emit_cached_split(b, dst, 4);
 }
 
 /* NIR loops are treated as a pair of AGX loops:

View File

@@ -469,6 +469,18 @@ agx_dest_index(nir_dest *dst)
                     agx_size_for_bits(nir_dest_bit_size(*dst)));
 }
 
+static inline agx_index
+agx_vec_for_dest(agx_context *ctx, nir_dest *dest)
+{
+   return agx_temp(ctx, agx_size_for_bits(nir_dest_bit_size(*dest)));
+}
+
+static inline agx_index
+agx_vec_for_intr(agx_context *ctx, nir_intrinsic_instr *instr)
+{
+   return agx_vec_for_dest(ctx, &instr->dest);
+}
+
 /* Iterators for AGX IR */
 #define agx_foreach_block(ctx, v) \