diff --git a/src/asahi/compiler/agx_compile.c b/src/asahi/compiler/agx_compile.c
index 30df9cee1e4..51ad0ab7144 100644
--- a/src/asahi/compiler/agx_compile.c
+++ b/src/asahi/compiler/agx_compile.c
@@ -136,6 +136,61 @@ agx_emit_fragment_out(agx_builder *b, nir_intrinsic_instr *instr)
          b->shader->key->fs.tib_formats[rt]);
 }
 
+static enum agx_format
+agx_format_for_bits(unsigned bits) /* load/store bit size -> AGX_FORMAT_I8/I16/I32 */
+{
+   switch (bits) {
+   case 8: return AGX_FORMAT_I8;
+   case 16: return AGX_FORMAT_I16;
+   case 32: return AGX_FORMAT_I32;
+   default: unreachable("Invalid bit size for load/store");
+   }
+}
+
+static void
+agx_emit_load_ubo(agx_builder *b, nir_intrinsic_instr *instr)
+{
+   bool kernel_input = (instr->intrinsic == nir_intrinsic_load_kernel_input); /* load_kernel_input shares this path with load_ubo */
+   nir_src *offset = nir_get_io_offset_src(instr);
+
+   if (!kernel_input && !nir_src_is_const(instr->src[0])) /* for load_ubo, src[0] is read below as the block index */
+      unreachable("todo: indirect UBO access");
+
+   /* Constant offsets for device_load are 16-bit */
+   bool offset_is_const = nir_src_is_const(*offset);
+   assert(offset_is_const && "todo: indirect UBO access");
+   int32_t const_offset = offset_is_const ? nir_src_as_int(*offset) : 0; /* the 0 fallback is only reachable when the assert above is compiled out */
+
+   /* Offsets are shifted by the type size, so divide that out */
+   unsigned bytes = nir_dest_bit_size(instr->dest) / 8;
+   assert((const_offset & (bytes - 1)) == 0); /* alignment check; the mask test assumes bytes is a power of two */
+   const_offset = const_offset / bytes;
+   int16_t const_as_16 = const_offset; /* narrowed copy, compared with the full value below to see whether it fits */
+
+   /* UBO blocks are specified (kernel inputs are always 0) */
+   uint32_t block = kernel_input ? 0 : nir_src_as_uint(instr->src[0]);
+
+   /* Each UBO has a 64-bit = 4 x 16-bit address */
+   unsigned num_ubos = b->shader->nir->info.num_ubos;
+   unsigned base_length = (num_ubos * 4);
+
+   /* Lookup the base address (TODO: indirection) */
+   agx_index base = agx_indexed_sysval(b->shader,
+                                       AGX_PUSH_UBO_BASES, AGX_SIZE_64, block, base_length);
+
+   /* Load the data */
+   assert(instr->num_components <= 4);
+
+   agx_device_load_to(b, agx_dest_index(&instr->dest),
+                      base,
+                      (offset_is_const && (const_offset == const_as_16)) ? /* element offset survives the 16-bit narrowing */
+                      agx_immediate(const_as_16) : agx_mov_imm(b, 32, const_offset), /* immediate when it fits, otherwise via agx_mov_imm */
+                      agx_format_for_bits(nir_dest_bit_size(instr->dest)),
+                      BITFIELD_MASK(instr->num_components), 0);
+
+   agx_wait(b, 0); /* NOTE(review): presumably waits on the same slot (0) passed to the load above -- confirm */
+}
+
 static void
 agx_emit_intrinsic(agx_builder *b, nir_intrinsic_instr *instr)
 {
@@ -170,6 +225,11 @@ agx_emit_intrinsic(agx_builder *b, nir_intrinsic_instr *instr)
          unreachable("Unsupported shader stage");
       break;
 
+   case nir_intrinsic_load_ubo:
+   case nir_intrinsic_load_kernel_input:
+      agx_emit_load_ubo(b, instr);
+      break;
+
    default:
       fprintf(stderr, "Unhandled intrinsic %s\n", nir_intrinsic_infos[instr->intrinsic].name);
       assert(0);