From 3a6a5281b39bfd7de212e3bb00725895cd74d5fb Mon Sep 17 00:00:00 2001
From: Alyssa Rosenzweig
Date: Sat, 3 Dec 2022 16:54:32 -0500
Subject: [PATCH] agx: Lower global loads/stores to AGX versions

This lets us do all the needed address arithmetic in a central place.

Signed-off-by: Alyssa Rosenzweig
Part-of:
---
 src/asahi/compiler/agx_compile.c           |   3 +
 src/asahi/compiler/agx_compiler.h          |   1 +
 src/asahi/compiler/agx_nir_lower_address.c | 163 +++++++++++++++++++++
 src/asahi/compiler/meson.build             |   1 +
 4 files changed, 168 insertions(+)
 create mode 100644 src/asahi/compiler/agx_nir_lower_address.c

diff --git a/src/asahi/compiler/agx_compile.c b/src/asahi/compiler/agx_compile.c
index b688dfe4b42..7448a0bd0e2 100644
--- a/src/asahi/compiler/agx_compile.c
+++ b/src/asahi/compiler/agx_compile.c
@@ -1709,6 +1709,9 @@ agx_optimize_nir(nir_shader *nir, unsigned *preamble_size)
       NIR_PASS(progress, nir, nir_opt_loop_unroll);
    } while (progress);
 
+   NIR_PASS_V(nir, agx_nir_lower_address);
+   NIR_PASS_V(nir, nir_lower_int64);
+
    NIR_PASS_V(nir, agx_nir_opt_preamble, preamble_size);
    NIR_PASS_V(nir, nir_opt_algebraic_late);
    NIR_PASS_V(nir, nir_opt_constant_folding);
diff --git a/src/asahi/compiler/agx_compiler.h b/src/asahi/compiler/agx_compiler.h
index 9538c1e5b15..80c25750178 100644
--- a/src/asahi/compiler/agx_compiler.h
+++ b/src/asahi/compiler/agx_compiler.h
@@ -832,6 +832,7 @@ bool agx_nir_lower_zs_emit(nir_shader *s);
 bool agx_nir_lower_array_texture(nir_shader *s);
 bool agx_nir_opt_preamble(nir_shader *s, unsigned *preamble_size);
 bool agx_nir_lower_load_mask(nir_shader *shader);
+bool agx_nir_lower_address(nir_shader *shader);
 bool agx_nir_lower_ubo(nir_shader *shader);
 
 #ifdef __cplusplus
diff --git a/src/asahi/compiler/agx_nir_lower_address.c b/src/asahi/compiler/agx_nir_lower_address.c
new file mode 100644
index 00000000000..368046f4eea
--- /dev/null
+++ b/src/asahi/compiler/agx_nir_lower_address.c
@@ -0,0 +1,163 @@
+/*
+ * Copyright 2022 Alyssa Rosenzweig
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "compiler/nir/nir_builder.h"
+#include "agx_compiler.h"
+
+/* Results of pattern matching */
+struct match {
+   nir_ssa_scalar base, offset;
+   bool has_offset;
+   bool sign_extend;
+
+   /* Signed shift. A negative shift indicates that the offset needs ushr
+    * applied. It's cheaper to fold the iadd and materialize an extra ushr
+    * than to leave the iadd untouched, so this is good.
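+    *
+    * For example, a 32-bit load implicitly scales its offset by 4. If the
+    * matched offset is not already shifted left by 2, we record shift = -2
+    * and later emit an explicit right-shift to compensate.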
+    */
+   int8_t shift;
+};
+
+/* Try to pattern match address calculation */
+static struct match
+match_address(nir_ssa_scalar base, int8_t format_shift)
+{
+   struct match match = {.base = base};
+
+   /* All address calculations are iadd at the root */
+   if (!nir_ssa_scalar_is_alu(base) ||
+       nir_ssa_scalar_alu_op(base) != nir_op_iadd)
+      return match;
+
+   /* Only 64+32 addition is supported, so look for an extension */
+   nir_ssa_scalar summands[] = {
+      nir_ssa_scalar_chase_alu_src(base, 0),
+      nir_ssa_scalar_chase_alu_src(base, 1),
+   };
+
+   for (unsigned i = 0; i < ARRAY_SIZE(summands); ++i) {
+      if (!nir_ssa_scalar_is_alu(summands[i]))
+         continue;
+
+      nir_op op = nir_ssa_scalar_alu_op(summands[i]);
+
+      if (op != nir_op_u2u64 && op != nir_op_i2i64)
+         continue;
+
+      match.base = summands[1 - i];
+      match.offset = nir_ssa_scalar_chase_alu_src(summands[i], 0);
+      match.sign_extend = (op == nir_op_i2i64);
+
+      /* Undo the implicit shift from using as offset */
+      match.shift = -format_shift;
+
+      /* Now try to fold in an ishl from the offset */
+      if (nir_ssa_scalar_is_alu(match.offset) &&
+          nir_ssa_scalar_alu_op(match.offset) == nir_op_ishl) {
+
+         nir_ssa_scalar shifted = nir_ssa_scalar_chase_alu_src(match.offset, 0);
+         nir_ssa_scalar shift = nir_ssa_scalar_chase_alu_src(match.offset, 1);
+
+         if (nir_ssa_scalar_is_const(shift)) {
+            int8_t new_shift = match.shift + nir_ssa_scalar_as_uint(shift);
+
+            /* Only fold in if we wouldn't overflow the lsl field */
+            if (new_shift <= 2) {
+               match.offset = shifted;
+               match.shift = new_shift;
+            }
+         }
+      }
+   }
+
+   return match;
+}
+
+static bool
+pass(struct nir_builder *b, nir_instr *instr, UNUSED void *data)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+   if (intr->intrinsic != nir_intrinsic_load_global &&
+       intr->intrinsic != nir_intrinsic_load_global_constant &&
+       intr->intrinsic != nir_intrinsic_store_global)
+      return false;
+
+   b->cursor = nir_before_instr(instr);
+
+   unsigned bitsize = intr->intrinsic == nir_intrinsic_store_global
+                         ? nir_src_bit_size(intr->src[0])
+                         : nir_dest_bit_size(intr->dest);
+
+   /* TODO: Handle more sizes */
+   assert(bitsize == 16 || bitsize == 32);
+   enum pipe_format format =
+      bitsize == 32 ? PIPE_FORMAT_R32_UINT : PIPE_FORMAT_R16_UINT;
+
+   unsigned format_shift = util_logbase2(util_format_get_blocksize(format));
+
+   nir_src *orig_offset = nir_get_io_offset_src(intr);
+   nir_ssa_scalar base = nir_ssa_scalar_resolved(orig_offset->ssa, 0);
+   struct match match = match_address(base, format_shift);
+
+   nir_ssa_def *offset =
+      match.offset.def != NULL
+         ? nir_channel(b, match.offset.def, match.offset.comp)
+         : nir_imm_int(b, 0);
+
+   /* If we were unable to fold in the shift, insert a right-shift now to undo
+    * the implicit left shift of the instruction.
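+    * An arithmetic shift is used when the offset was sign-extended, so
+    * negative offsets keep their sign.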
+    */
+   if (match.shift < 0) {
+      if (match.sign_extend)
+         offset = nir_ishr_imm(b, offset, -match.shift);
+      else
+         offset = nir_ushr_imm(b, offset, -match.shift);
+
+      match.shift = 0;
+   }
+
+   assert(match.shift >= 0);
+   nir_ssa_def *new_base = nir_channel(b, match.base.def, match.base.comp);
+
+   if (intr->intrinsic == nir_intrinsic_load_global) {
+      nir_ssa_def *repl =
+         nir_load_agx(b, nir_dest_num_components(intr->dest),
+                      nir_dest_bit_size(intr->dest), new_base, offset,
+                      .access = nir_intrinsic_access(intr), .base = match.shift,
+                      .format = format, .sign_extend = match.sign_extend);
+
+      nir_ssa_def_rewrite_uses(&intr->dest.ssa, repl);
+   } else if (intr->intrinsic == nir_intrinsic_load_global_constant) {
+      nir_ssa_def *repl = nir_load_constant_agx(
+         b, nir_dest_num_components(intr->dest), nir_dest_bit_size(intr->dest),
+         new_base, offset, .access = nir_intrinsic_access(intr),
+         .base = match.shift, .format = format,
+         .sign_extend = match.sign_extend);
+
+      nir_ssa_def_rewrite_uses(&intr->dest.ssa, repl);
+   } else {
+      nir_store_agx(b, intr->src[0].ssa, new_base, offset,
+                    .access = nir_intrinsic_access(intr), .base = match.shift,
+                    .format = format, .sign_extend = match.sign_extend);
+   }
+
+   nir_instr_remove(instr);
+   return true;
+}
+
+bool
+agx_nir_lower_address(nir_shader *shader)
+{
+   return nir_shader_instructions_pass(
+      shader, pass, nir_metadata_block_index | nir_metadata_dominance, NULL);
+}
diff --git a/src/asahi/compiler/meson.build b/src/asahi/compiler/meson.build
index 5f0c708c557..4f449d7a259 100644
--- a/src/asahi/compiler/meson.build
+++ b/src/asahi/compiler/meson.build
@@ -24,6 +24,7 @@ libasahi_agx_files = files(
   'agx_dce.c',
   'agx_liveness.c',
   'agx_insert_waits.c',
+  'agx_nir_lower_address.c',
   'agx_nir_lower_zs_emit.c',
   'agx_nir_lower_texture.c',
   'agx_nir_lower_load_mask.c',
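
As an illustration of the addressing scheme this lowers to (not part of the
patch): the pass treats load_agx/store_agx as computing
base + (offset << (format_shift + lsl)), where format_shift comes from the
format's block size and lsl is the 0..2 shift field set from match.shift.
A minimal standalone C sketch of that bookkeeping follows; the
effective_address helper is hypothetical and only models the assumed
hardware addressing mode.

    #include <assert.h>
    #include <stdint.h>

    /* Assumed addressing mode of the AGX memory instructions: the 32-bit
     * offset is scaled by the format's block size, shifted further left by
     * the instruction's lsl field. */
    static uint64_t
    effective_address(uint64_t base, uint32_t offset, unsigned format_shift,
                      unsigned lsl)
    {
       return base + ((uint64_t)offset << (format_shift + lsl));
    }

    int
    main(void)
    {
       uint64_t base = 0x1000;
       uint32_t index = 7;

       /* addr = base + u2u64(index << 2), 32-bit elements (format_shift 2):
        * the ishl folds, match.shift = -2 + 2 = 0, no extra instructions. */
       assert(effective_address(base, index, 2, 0) == base + index * 4u);

       /* addr = base + u2u64(index << 4): match.shift = 2 fits the lsl
        * field, so the shift disappears into the addressing mode. */
       assert(effective_address(base, index, 2, 2) == base + index * 16u);

       /* addr = base + u2u64(bytes), no ishl: match.shift = -2, so the pass
        * emits an explicit ushr by 2 first (bytes assumed 4-aligned). */
       uint32_t bytes = 64;
       assert(effective_address(base, bytes >> 2, 2, 0) == base + bytes);

       return 0;
    }

This is why folding an ishl into the lsl field is profitable: the shift is
absorbed by the addressing mode instead of costing a separate instruction.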