microsoft/compiler: Add a lowering pass for scan ops that aren't supported
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21996>
This commit is contained in:
@@ -2243,3 +2243,103 @@ dxil_nir_split_unaligned_loads_stores(nir_shader *shader, nir_variable_mode mode
|
||||
|
||||
return progress;
|
||||
}
|
||||
|
||||
static void
|
||||
lower_inclusive_to_exclusive(nir_builder *b, nir_intrinsic_instr *intr)
|
||||
{
|
||||
b->cursor = nir_after_instr(&intr->instr);
|
||||
|
||||
nir_op op = nir_intrinsic_reduction_op(intr);
|
||||
intr->intrinsic = nir_intrinsic_exclusive_scan;
|
||||
nir_intrinsic_set_reduction_op(intr, op);
|
||||
|
||||
nir_ssa_def *final_val = nir_build_alu2(b, nir_intrinsic_reduction_op(intr),
|
||||
&intr->dest.ssa, intr->src[0].ssa);
|
||||
nir_ssa_def_rewrite_uses_after(&intr->dest.ssa, final_val, final_val->parent_instr);
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_subgroup_scan(nir_builder *b, nir_instr *instr, void *data)
|
||||
{
|
||||
if (instr->type != nir_instr_type_intrinsic)
|
||||
return false;
|
||||
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
|
||||
switch (intr->intrinsic) {
|
||||
case nir_intrinsic_exclusive_scan:
|
||||
case nir_intrinsic_inclusive_scan:
|
||||
switch ((nir_op)nir_intrinsic_reduction_op(intr)) {
|
||||
case nir_op_iadd:
|
||||
case nir_op_fadd:
|
||||
case nir_op_imul:
|
||||
case nir_op_fmul:
|
||||
if (intr->intrinsic == nir_intrinsic_exclusive_scan)
|
||||
return false;
|
||||
lower_inclusive_to_exclusive(b, intr);
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
b->cursor = nir_before_instr(instr);
|
||||
nir_op op = nir_intrinsic_reduction_op(intr);
|
||||
nir_ssa_def *subgroup_id = nir_build_load_subgroup_invocation(b);
|
||||
nir_ssa_def *active_threads = nir_build_ballot(b, 4, 32, nir_imm_bool(b, true));
|
||||
nir_ssa_def *base_value;
|
||||
uint32_t bit_size = intr->dest.ssa.bit_size;
|
||||
if (op == nir_op_iand || op == nir_op_umin)
|
||||
base_value = nir_imm_intN_t(b, ~0ull, bit_size);
|
||||
else if (op == nir_op_imin)
|
||||
base_value = nir_imm_intN_t(b, (1ull << (bit_size - 1)) - 1, bit_size);
|
||||
else if (op == nir_op_imax)
|
||||
base_value = nir_imm_intN_t(b, 1ull << (bit_size - 1), bit_size);
|
||||
else if (op == nir_op_fmax)
|
||||
base_value = nir_imm_floatN_t(b, -INFINITY, bit_size);
|
||||
else if (op == nir_op_fmin)
|
||||
base_value = nir_imm_floatN_t(b, INFINITY, bit_size);
|
||||
else
|
||||
base_value = nir_imm_intN_t(b, 0, bit_size);
|
||||
|
||||
nir_variable *loop_counter_var = nir_local_variable_create(b->impl, glsl_uint_type(), "subgroup_loop_counter");
|
||||
nir_variable *result_var = nir_local_variable_create(b->impl,
|
||||
glsl_vector_type(nir_get_glsl_base_type_for_nir_type(
|
||||
nir_op_infos[op].input_types[0] | bit_size), 1),
|
||||
"subgroup_loop_result");
|
||||
nir_store_var(b, loop_counter_var, nir_imm_int(b, 0), 1);
|
||||
nir_store_var(b, result_var, base_value, 1);
|
||||
nir_loop *loop = nir_push_loop(b);
|
||||
nir_ssa_def *loop_counter = nir_load_var(b, loop_counter_var);
|
||||
nir_if *nif = nir_push_if(b, intr->intrinsic == nir_intrinsic_inclusive_scan ?
|
||||
nir_ige(b, subgroup_id, loop_counter) :
|
||||
nir_ilt(b, loop_counter, subgroup_id));
|
||||
nir_if *if_active_thread = nir_push_if(b, nir_build_ballot_bitfield_extract(b, 32, active_threads, loop_counter));
|
||||
nir_ssa_def *result = nir_build_alu2(b, op,
|
||||
nir_load_var(b, result_var),
|
||||
nir_build_read_invocation(b, intr->src[0].ssa, loop_counter));
|
||||
nir_store_var(b, result_var, result, 1);
|
||||
nir_pop_if(b, if_active_thread);
|
||||
nir_store_var(b, loop_counter_var, nir_iadd_imm(b, loop_counter, 1), 1);
|
||||
nir_jump(b, nir_jump_continue);
|
||||
nir_pop_if(b, nif);
|
||||
nir_jump(b, nir_jump_break);
|
||||
nir_pop_loop(b, loop);
|
||||
|
||||
result = nir_load_var(b, result_var);
|
||||
nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
dxil_nir_lower_unsupported_subgroup_scan(nir_shader *s)
|
||||
{
|
||||
bool ret = nir_shader_instructions_pass(s, lower_subgroup_scan, nir_metadata_none, NULL);
|
||||
if (ret) {
|
||||
/* Lower the ballot bitfield tests */
|
||||
nir_lower_subgroups_options options = { .ballot_bit_size = 32, .ballot_components = 4 };
|
||||
nir_lower_subgroups(s, &options);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
@@ -82,6 +82,7 @@ bool dxil_nir_lower_sample_pos(nir_shader *s);
|
||||
bool dxil_nir_lower_subgroup_id(nir_shader *s);
|
||||
bool dxil_nir_lower_num_subgroups(nir_shader *s);
|
||||
bool dxil_nir_split_unaligned_loads_stores(nir_shader *shader, nir_variable_mode modes);
|
||||
/* Lowers inclusive/exclusive subgroup scan intrinsics; returns the progress flag of the instruction pass. */
bool dxil_nir_lower_unsupported_subgroup_scan(nir_shader *s);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
Reference in New Issue
Block a user