diff --git a/src/asahi/compiler/agx_compile.c b/src/asahi/compiler/agx_compile.c index 148fe891ab3..ca4b5a7a63b 100644 --- a/src/asahi/compiler/agx_compile.c +++ b/src/asahi/compiler/agx_compile.c @@ -257,6 +257,13 @@ agx_emit_alu(agx_builder *b, nir_alu_instr *instr) return I; } + case nir_op_fsin_agx: /* fsin_agx: apply the sin_pt_1 fixup to s0, feed it to sin_pt_2 (sinc), then multiply the sinc by the fixup into dst */ + { + agx_index fixup = agx_sin_pt_1(b, s0); + agx_index sinc = agx_sin_pt_2(b, fixup); + return agx_fmul_to(b, dst, sinc, fixup); + } + case nir_op_vec2: case nir_op_vec3: case nir_op_vec4: @@ -403,6 +410,46 @@ glsl_type_size(const struct glsl_type *type, bool bindless) return glsl_count_attribute_slots(type, false); } +/* Filter handed to nir_shader_lower_instructions by agx_lower_sincos: returns true only for ALU instructions whose opcode is fsin or fcos. The second parameter is unused. */ static bool +agx_lower_sincos_filter(const nir_instr *instr, UNUSED const void *_) +{ + if (instr->type != nir_instr_type_alu) + return false; /* not an ALU instruction, nothing to lower */ + + nir_alu_instr *alu = nir_instr_as_alu(instr); + return alu->op == nir_op_fsin || alu->op == nir_op_fcos; /* select exactly the two opcodes the lowering handles */ +} + +/* Sine and cosine are implemented via the sin_pt_1 and sin_pt_2 opcodes for + * heavy lifting. sin_pt_2 implements sinc in the first quadrant, expressed in + * turns (sin (tau x) / x), while sin_pt_1 implements a piecewise sign/offset + * fixup to transform a quadrant angle [0, 4] to [-1, 1]. The NIR opcode + * fsin_agx models the fixup, sinc, and multiply to obtain sine, so we just + * need to change units from radians to quadrants modulo turns. Cosine is + * implemented by shifting by one quadrant: cos(x) = sin(x + tau/4). 
+ */ + + /* Lowering callback paired with agx_lower_sincos_filter: builds an fsin_agx of the fsin/fcos source after converting the angle from radians to quadrants, and returns that new value. The third parameter is unused. */ static nir_ssa_def * + agx_lower_sincos_impl(struct nir_builder *b, nir_instr *instr, UNUSED void *_) +{ + nir_alu_instr *alu = nir_instr_as_alu(instr); + nir_ssa_def *x = nir_mov_alu(b, alu->src[0], 1); /* source 0 of the fsin/fcos; NOTE(review): the 1 is presumably the component count (scalar, since this runs after nir_lower_alu_to_scalar) -- confirm */ + nir_ssa_def *turns = nir_fmul_imm(b, x, M_1_PI * 0.5f); /* radians to turns: scale by 1/(2 pi) */ + + if (alu->op == nir_op_fcos) + turns = nir_fadd_imm(b, turns, 0.25f); /* cos(x) = sin(x + tau/4): advance by a quarter turn, i.e. one quadrant */ + + nir_ssa_def *quadrants = nir_fmul_imm(b, nir_ffract(b, turns), 4.0); /* keep only the fractional part of the turn count (angle modulo one turn), then scale by 4 to express it in quadrants */ + return nir_fsin_agx(b, quadrants); +} + +/* Applies the fsin/fcos lowering to a shader through nir_shader_lower_instructions (no callback data is passed) and returns that call's result. */ static bool +agx_lower_sincos(nir_shader *shader) +{ + return nir_shader_lower_instructions(shader, + agx_lower_sincos_filter, agx_lower_sincos_impl, NULL); +} + static void agx_optimize_nir(nir_shader *nir) { @@ -419,6 +466,7 @@ agx_optimize_nir(nir_shader *nir) NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL); NIR_PASS_V(nir, nir_lower_load_const_to_scalar); NIR_PASS_V(nir, nir_lower_flrp, 16 | 32 | 64, false); + NIR_PASS_V(nir, agx_lower_sincos); /* invoked once here, after the scalarizing and flrp lowering passes and before the do-loop that follows */ do { progress = false;