From ccc9540dae9e0c9b9fb6ecdaabecc748e487f1e9 Mon Sep 17 00:00:00 2001 From: Jesse Natalie Date: Fri, 17 Mar 2023 14:52:27 -0700 Subject: [PATCH] microsoft/compiler: Add a lowering pass for scan ops that aren't supported Part-of: --- src/microsoft/compiler/dxil_nir.c | 100 ++++++++++++++++++++++++++++++ src/microsoft/compiler/dxil_nir.h | 1 + 2 files changed, 101 insertions(+) diff --git a/src/microsoft/compiler/dxil_nir.c b/src/microsoft/compiler/dxil_nir.c index f61a8f60e9d..53bdaf53947 100644 --- a/src/microsoft/compiler/dxil_nir.c +++ b/src/microsoft/compiler/dxil_nir.c @@ -2243,3 +2243,103 @@ dxil_nir_split_unaligned_loads_stores(nir_shader *shader, nir_variable_mode mode return progress; } + +static void +lower_inclusive_to_exclusive(nir_builder *b, nir_intrinsic_instr *intr) +{ + b->cursor = nir_after_instr(&intr->instr); + + nir_op op = nir_intrinsic_reduction_op(intr); + intr->intrinsic = nir_intrinsic_exclusive_scan; + nir_intrinsic_set_reduction_op(intr, op); + + nir_ssa_def *final_val = nir_build_alu2(b, nir_intrinsic_reduction_op(intr), + &intr->dest.ssa, intr->src[0].ssa); + nir_ssa_def_rewrite_uses_after(&intr->dest.ssa, final_val, final_val->parent_instr); +} + +static bool +lower_subgroup_scan(nir_builder *b, nir_instr *instr, void *data) +{ + if (instr->type != nir_instr_type_intrinsic) + return false; + nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); + switch (intr->intrinsic) { + case nir_intrinsic_exclusive_scan: + case nir_intrinsic_inclusive_scan: + switch ((nir_op)nir_intrinsic_reduction_op(intr)) { + case nir_op_iadd: + case nir_op_fadd: + case nir_op_imul: + case nir_op_fmul: + if (intr->intrinsic == nir_intrinsic_exclusive_scan) + return false; + lower_inclusive_to_exclusive(b, intr); + return true; + default: + break; + } + break; + default: + return false; + } + + b->cursor = nir_before_instr(instr); + nir_op op = nir_intrinsic_reduction_op(intr); + nir_ssa_def *subgroup_id = nir_build_load_subgroup_invocation(b); + nir_ssa_def *active_threads = nir_build_ballot(b, 4, 32, nir_imm_bool(b, true)); + nir_ssa_def *base_value; + uint32_t bit_size = intr->dest.ssa.bit_size; + if (op == nir_op_iand || op == nir_op_umin) + base_value = nir_imm_intN_t(b, ~0ull, bit_size); + else if (op == nir_op_imin) + base_value = nir_imm_intN_t(b, (1ull << (bit_size - 1)) - 1, bit_size); + else if (op == nir_op_imax) + base_value = nir_imm_intN_t(b, 1ull << (bit_size - 1), bit_size); + else if (op == nir_op_fmax) + base_value = nir_imm_floatN_t(b, -INFINITY, bit_size); + else if (op == nir_op_fmin) + base_value = nir_imm_floatN_t(b, INFINITY, bit_size); + else + base_value = nir_imm_intN_t(b, 0, bit_size); + + nir_variable *loop_counter_var = nir_local_variable_create(b->impl, glsl_uint_type(), "subgroup_loop_counter"); + nir_variable *result_var = nir_local_variable_create(b->impl, + glsl_vector_type(nir_get_glsl_base_type_for_nir_type( + nir_op_infos[op].input_types[0] | bit_size), 1), + "subgroup_loop_result"); + nir_store_var(b, loop_counter_var, nir_imm_int(b, 0), 1); + nir_store_var(b, result_var, base_value, 1); + nir_loop *loop = nir_push_loop(b); + nir_ssa_def *loop_counter = nir_load_var(b, loop_counter_var); + nir_if *nif = nir_push_if(b, intr->intrinsic == nir_intrinsic_inclusive_scan ? + nir_ige(b, subgroup_id, loop_counter) : + nir_ilt(b, loop_counter, subgroup_id)); + nir_if *if_active_thread = nir_push_if(b, nir_build_ballot_bitfield_extract(b, 32, active_threads, loop_counter)); + nir_ssa_def *result = nir_build_alu2(b, op, + nir_load_var(b, result_var), + nir_build_read_invocation(b, intr->src[0].ssa, loop_counter)); + nir_store_var(b, result_var, result, 1); + nir_pop_if(b, if_active_thread); + nir_store_var(b, loop_counter_var, nir_iadd_imm(b, loop_counter, 1), 1); + nir_jump(b, nir_jump_continue); + nir_pop_if(b, nif); + nir_jump(b, nir_jump_break); + nir_pop_loop(b, loop); + + result = nir_load_var(b, result_var); + nir_ssa_def_rewrite_uses(&intr->dest.ssa, result); + return true; +} + +bool +dxil_nir_lower_unsupported_subgroup_scan(nir_shader *s) +{ + bool ret = nir_shader_instructions_pass(s, lower_subgroup_scan, nir_metadata_none, NULL); + if (ret) { + /* Lower the ballot bitfield tests */ + nir_lower_subgroups_options options = { .ballot_bit_size = 32, .ballot_components = 4 }; + nir_lower_subgroups(s, &options); + } + return ret; +} diff --git a/src/microsoft/compiler/dxil_nir.h b/src/microsoft/compiler/dxil_nir.h index d1f236ae26f..df0d66817c0 100644 --- a/src/microsoft/compiler/dxil_nir.h +++ b/src/microsoft/compiler/dxil_nir.h @@ -82,6 +82,7 @@ bool dxil_nir_lower_sample_pos(nir_shader *s); bool dxil_nir_lower_subgroup_id(nir_shader *s); bool dxil_nir_lower_num_subgroups(nir_shader *s); bool dxil_nir_split_unaligned_loads_stores(nir_shader *shader, nir_variable_mode modes); +bool dxil_nir_lower_unsupported_subgroup_scan(nir_shader *s); #ifdef __cplusplus }