microsoft/compiler: Add a lowering pass for scan ops that aren't supported

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21996>
This commit is contained in:
Jesse Natalie
2023-03-17 14:52:27 -07:00
committed by Marge Bot
parent 981fe2bf42
commit ccc9540dae
2 changed files with 101 additions and 0 deletions

View File

@@ -2243,3 +2243,103 @@ dxil_nir_split_unaligned_loads_stores(nir_shader *shader, nir_variable_mode mode
return progress;
}
static void
lower_inclusive_to_exclusive(nir_builder *b, nir_intrinsic_instr *intr)
{
b->cursor = nir_after_instr(&intr->instr);
nir_op op = nir_intrinsic_reduction_op(intr);
intr->intrinsic = nir_intrinsic_exclusive_scan;
nir_intrinsic_set_reduction_op(intr, op);
nir_ssa_def *final_val = nir_build_alu2(b, nir_intrinsic_reduction_op(intr),
&intr->dest.ssa, intr->src[0].ssa);
nir_ssa_def_rewrite_uses_after(&intr->dest.ssa, final_val, final_val->parent_instr);
}
static bool
lower_subgroup_scan(nir_builder *b, nir_instr *instr, void *data)
{
if (instr->type != nir_instr_type_intrinsic)
return false;
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
switch (intr->intrinsic) {
case nir_intrinsic_exclusive_scan:
case nir_intrinsic_inclusive_scan:
switch ((nir_op)nir_intrinsic_reduction_op(intr)) {
case nir_op_iadd:
case nir_op_fadd:
case nir_op_imul:
case nir_op_fmul:
if (intr->intrinsic == nir_intrinsic_exclusive_scan)
return false;
lower_inclusive_to_exclusive(b, intr);
return true;
default:
break;
}
break;
default:
return false;
}
b->cursor = nir_before_instr(instr);
nir_op op = nir_intrinsic_reduction_op(intr);
nir_ssa_def *subgroup_id = nir_build_load_subgroup_invocation(b);
nir_ssa_def *active_threads = nir_build_ballot(b, 4, 32, nir_imm_bool(b, true));
nir_ssa_def *base_value;
uint32_t bit_size = intr->dest.ssa.bit_size;
if (op == nir_op_iand || op == nir_op_umin)
base_value = nir_imm_intN_t(b, ~0ull, bit_size);
else if (op == nir_op_imin)
base_value = nir_imm_intN_t(b, (1ull << (bit_size - 1)) - 1, bit_size);
else if (op == nir_op_imax)
base_value = nir_imm_intN_t(b, 1ull << (bit_size - 1), bit_size);
else if (op == nir_op_fmax)
base_value = nir_imm_floatN_t(b, -INFINITY, bit_size);
else if (op == nir_op_fmin)
base_value = nir_imm_floatN_t(b, INFINITY, bit_size);
else
base_value = nir_imm_intN_t(b, 0, bit_size);
nir_variable *loop_counter_var = nir_local_variable_create(b->impl, glsl_uint_type(), "subgroup_loop_counter");
nir_variable *result_var = nir_local_variable_create(b->impl,
glsl_vector_type(nir_get_glsl_base_type_for_nir_type(
nir_op_infos[op].input_types[0] | bit_size), 1),
"subgroup_loop_result");
nir_store_var(b, loop_counter_var, nir_imm_int(b, 0), 1);
nir_store_var(b, result_var, base_value, 1);
nir_loop *loop = nir_push_loop(b);
nir_ssa_def *loop_counter = nir_load_var(b, loop_counter_var);
nir_if *nif = nir_push_if(b, intr->intrinsic == nir_intrinsic_inclusive_scan ?
nir_ige(b, subgroup_id, loop_counter) :
nir_ilt(b, loop_counter, subgroup_id));
nir_if *if_active_thread = nir_push_if(b, nir_build_ballot_bitfield_extract(b, 32, active_threads, loop_counter));
nir_ssa_def *result = nir_build_alu2(b, op,
nir_load_var(b, result_var),
nir_build_read_invocation(b, intr->src[0].ssa, loop_counter));
nir_store_var(b, result_var, result, 1);
nir_pop_if(b, if_active_thread);
nir_store_var(b, loop_counter_var, nir_iadd_imm(b, loop_counter, 1), 1);
nir_jump(b, nir_jump_continue);
nir_pop_if(b, nif);
nir_jump(b, nir_jump_break);
nir_pop_loop(b, loop);
result = nir_load_var(b, result_var);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
return true;
}
bool
dxil_nir_lower_unsupported_subgroup_scan(nir_shader *s)
{
bool ret = nir_shader_instructions_pass(s, lower_subgroup_scan, nir_metadata_none, NULL);
if (ret) {
/* Lower the ballot bitfield tests */
nir_lower_subgroups_options options = { .ballot_bit_size = 32, .ballot_components = 4 };
nir_lower_subgroups(s, &options);
}
return ret;
}

View File

@@ -82,6 +82,7 @@ bool dxil_nir_lower_sample_pos(nir_shader *s);
bool dxil_nir_lower_subgroup_id(nir_shader *s);
bool dxil_nir_lower_num_subgroups(nir_shader *s);
bool dxil_nir_split_unaligned_loads_stores(nir_shader *shader, nir_variable_mode modes);
bool dxil_nir_lower_unsupported_subgroup_scan(nir_shader *s);
#ifdef __cplusplus
}