nir/lower_shader_calls: lower scratch access to the address format internally

For a follow-up optimization, we would like to track scratch loads.
This isn't possible with global load/store intrinsics, so use a couple
of special intrinsics in the pass and only lower them to the real
global/scratch intrinsics at the end.

Signed-off-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Reviewed-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16556>
Author:    Lionel Landwerlin
Date:      2022-05-18 18:29:10 +03:00
Committed: Marge Bot
Commit:    5a9f8d21d0 (parent df685b4f9c)

3 changed files with 119 additions and 26 deletions
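
To illustrate the motivation above, here is a minimal sketch (not part of this change) of how a follow-up pass could walk the intermediate load_stack intrinsics before they are lowered away. The helper name is made up, and it assumes the index accessors nir_intrinsic_call_idx() and nir_intrinsic_value_id() auto-generated from the definitions added to nir_intrinsics.py below:

#include <stdio.h>
#include "nir.h"

/* Sketch only: dump every stack load so a later optimization could reason
 * about which spilled values are reloaded, and where. */
static void
print_stack_loads(nir_shader *shader)
{
   nir_foreach_function(func, shader) {
      if (!func->impl)
         continue;

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            if (intrin->intrinsic != nir_intrinsic_load_stack)
               continue;

            /* BASE is the stack offset; CALL_IDX/VALUE_ID identify the
             * spilled SSA value (see the new indices below). */
            fprintf(stderr, "load_stack: base=%d call_idx=%u value_id=%u\n",
                    nir_intrinsic_base(intrin),
                    nir_intrinsic_call_idx(intrin),
                    nir_intrinsic_value_id(intrin));
         }
      }
   }
}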

src/compiler/nir/nir_divergence_analysis.c

@@ -644,6 +644,7 @@ visit_intrinsic(nir_shader *shader, nir_intrinsic_instr *instr)
    case nir_intrinsic_load_topology_id_intel:
    case nir_intrinsic_load_scratch_base_ptr:
    case nir_intrinsic_ordered_xfb_counter_add_amd:
+   case nir_intrinsic_load_stack:
       is_divergent = true;
       break;

src/compiler/nir/nir_intrinsics.py

@@ -272,6 +272,9 @@ index("unsigned", "saturate")
 # Whether or not trace_ray_intel is synchronous
 index("bool", "synchronous")

+# Value ID to identify SSA value loaded/stored on the stack
+index("unsigned", "value_id")
+
 intrinsic("nop", flags=[CAN_ELIMINATE])

 intrinsic("convert_alu_types", dest_comp=0, src_comp=[0],
@@ -1077,6 +1080,17 @@ store("global_2x32", [2], [WRITE_MASK, ACCESS, ALIGN_MUL, ALIGN_OFFSET])
 # src[] = { value, offset }.
 store("scratch", [1], [ALIGN_MUL, ALIGN_OFFSET, WRITE_MASK])

+# Intrinsic to load/store from the call stack.
+# BASE is the offset relative to the current position of the stack
+# src[] = { }.
+intrinsic("load_stack", [], dest_comp=0,
+          indices=[BASE, ALIGN_MUL, ALIGN_OFFSET, CALL_IDX, VALUE_ID],
+          flags=[CAN_ELIMINATE])
+# src[] = { value }.
+intrinsic("store_stack", [0],
+          indices=[BASE, ALIGN_MUL, ALIGN_OFFSET, WRITE_MASK, CALL_IDX, VALUE_ID])
+
 # A bit field to implement SPIRV FragmentShadingRateKHR
 # bit | name | description
 # 0 | Vertical2Pixels | Fragment invocation covers 2 pixels vertically
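
As a usage illustration of the definitions above (a sketch, not code from this change), the generated nir_store_stack()/nir_load_stack() builder helpers can be used roughly like this; the function name and the base/call_idx/value_id/alignment values are placeholders:

#include "nir_builder.h"

/* Sketch only: spill `def` to the call stack and reload it.  The real pass
 * (spill_fill below) computes base/call_idx/value_id per spilled value. */
static nir_ssa_def *
spill_then_reload(nir_builder *b, nir_ssa_def *def)
{
   const unsigned comp_size = def->bit_size / 8;

   nir_store_stack(b, def,
                   .base = 16,     /* placeholder stack offset */
                   .call_idx = 0,  /* placeholder call index */
                   .value_id = 7,  /* placeholder value ID */
                   .align_mul = comp_size,
                   .write_mask = BITFIELD_MASK(def->num_components));

   return nir_load_stack(b, def->num_components, def->bit_size,
                         .base = 16,
                         .call_idx = 0,
                         .value_id = 7,
                         .align_mul = comp_size);
}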

src/compiler/nir/nir_lower_shader_calls.c

@@ -271,36 +271,27 @@ rewrite_instr_src_from_phi_builder(nir_src *src, void *_pbv_arr)
 }

 static nir_ssa_def *
-spill_fill(nir_builder *before, nir_builder *after, nir_ssa_def *def, unsigned offset,
-           nir_address_format address_format, unsigned stack_alignment)
+spill_fill(nir_builder *before, nir_builder *after, nir_ssa_def *def,
+           unsigned value_id, unsigned call_idx,
+           unsigned offset, unsigned stack_alignment)
 {
    const unsigned comp_size = def->bit_size / 8;

-   switch(address_format) {
-   case nir_address_format_32bit_offset:
-      nir_store_scratch(before, def, nir_imm_int(before, offset),
-                        .align_mul = MIN2(comp_size, stack_alignment),
-                        .write_mask = BITFIELD_MASK(def->num_components));
-      def = nir_load_scratch(after, def->num_components, def->bit_size,
-                             nir_imm_int(after, offset), .align_mul = MIN2(comp_size, stack_alignment));
-      break;
-   case nir_address_format_64bit_global: {
-      nir_ssa_def *addr = nir_iadd_imm(before, nir_load_scratch_base_ptr(before, 1, 64, 1), offset);
-      nir_store_global(before, addr, MIN2(comp_size, stack_alignment), def, ~0);
-      addr = nir_iadd_imm(after, nir_load_scratch_base_ptr(after, 1, 64, 1), offset);
-      def = nir_load_global(after, addr, MIN2(comp_size, stack_alignment),
-                            def->num_components, def->bit_size);
-      break;
-   }
-   default:
-      unreachable("Unimplemented address format");
-   }
-   return def;
+   nir_store_stack(before, def,
+                   .base = offset,
+                   .call_idx = call_idx,
+                   .align_mul = MIN2(comp_size, stack_alignment),
+                   .value_id = value_id,
+                   .write_mask = BITFIELD_MASK(def->num_components));
+   return nir_load_stack(after, def->num_components, def->bit_size,
+                         .base = offset,
+                         .call_idx = call_idx,
+                         .value_id = value_id,
+                         .align_mul = MIN2(comp_size, stack_alignment));
 }

 static void
 spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls,
-                                      nir_address_format address_format,
                                       unsigned stack_alignment)
 {
    /* TODO: If a SSA def is filled more than once, we probably want to just
@@ -439,8 +430,9 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls,
             const unsigned comp_size = def->bit_size / 8;
             offset = ALIGN(offset, comp_size);

-            def = spill_fill(&before, &after, def, offset,
-                             address_format,stack_alignment);
+            def = spill_fill(&before, &after, def,
+                             index, call_idx,
+                             offset, stack_alignment);

             if (is_bool)
                def = nir_b2b1(&after, def);
@@ -1135,6 +1127,88 @@ replace_resume_with_halt(nir_shader *shader, nir_instr *keep)
    }
 }

+struct lower_scratch_state {
+   nir_address_format address_format;
+};
+
+static bool
+lower_stack_instr_to_scratch(struct nir_builder *b, nir_instr *instr, void *data)
+{
+   struct lower_scratch_state *state = data;
+
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   nir_intrinsic_instr *stack = nir_instr_as_intrinsic(instr);
+   switch (stack->intrinsic) {
+   case nir_intrinsic_load_stack: {
+      b->cursor = nir_instr_remove(instr);
+      nir_ssa_def *data, *old_data = nir_instr_ssa_def(instr);
+
+      if (state->address_format == nir_address_format_64bit_global) {
+         nir_ssa_def *addr = nir_iadd_imm(b,
+                                          nir_load_scratch_base_ptr(b, 1, 64, 1),
+                                          nir_intrinsic_base(stack));
+         data = nir_load_global(b, addr,
+                                nir_intrinsic_align_mul(stack),
+                                stack->dest.ssa.num_components,
+                                stack->dest.ssa.bit_size);
+      } else {
+         assert(state->address_format == nir_address_format_32bit_offset);
+         data = nir_load_scratch(b,
+                                 old_data->num_components,
+                                 old_data->bit_size,
+                                 nir_imm_int(b, nir_intrinsic_base(stack)),
+                                 .align_mul = nir_intrinsic_align_mul(stack));
+      }
+      nir_ssa_def_rewrite_uses(old_data, data);
+      break;
+   }
+
+   case nir_intrinsic_store_stack: {
+      b->cursor = nir_instr_remove(instr);
+      nir_ssa_def *data = stack->src[0].ssa;
+
+      if (state->address_format == nir_address_format_64bit_global) {
+         nir_ssa_def *addr = nir_iadd_imm(b,
+                                          nir_load_scratch_base_ptr(b, 1, 64, 1),
+                                          nir_intrinsic_base(stack));
+         nir_store_global(b, addr,
+                          nir_intrinsic_align_mul(stack),
+                          data,
+                          BITFIELD_MASK(data->num_components));
+      } else {
+         assert(state->address_format == nir_address_format_32bit_offset);
+         nir_store_scratch(b, data,
+                           nir_imm_int(b, nir_intrinsic_base(stack)),
+                           .align_mul = nir_intrinsic_align_mul(stack),
+                           .write_mask = BITFIELD_MASK(data->num_components));
+      }
+      break;
+   }
+
+   default:
+      return false;
+   }
+
+   return true;
+}
+
+static bool
+nir_lower_stack_to_scratch(nir_shader *shader,
+                           nir_address_format address_format)
+{
+   struct lower_scratch_state state = {
+      .address_format = address_format,
+   };
+
+   return nir_shader_instructions_pass(shader,
+                                       lower_stack_instr_to_scratch,
+                                       nir_metadata_block_index |
+                                       nir_metadata_dominance,
+                                       &state);
+}
+
 /** Lower shader call instructions to split shaders.
  *
  * Shader calls can be split into an initial shader and a series of "resume"
@@ -1196,7 +1270,7 @@ nir_lower_shader_calls(nir_shader *shader,
    }

    NIR_PASS_V(shader, spill_ssa_defs_and_lower_shader_calls,
-              num_calls, address_format, stack_alignment);
+              num_calls, stack_alignment);

    nir_opt_remove_phis(shader);
@@ -1222,6 +1296,10 @@ nir_lower_shader_calls(nir_shader *shader,
       nir_opt_if(resume_shaders[i], nir_opt_if_optimize_phi_true_false);
    }

+   NIR_PASS_V(shader, nir_lower_stack_to_scratch, address_format);
+   for (unsigned i = 0; i < num_calls; i++)
+      NIR_PASS_V(resume_shaders[i], nir_lower_stack_to_scratch, address_format);
+
    *resume_shaders_out = resume_shaders;
    *num_resume_shaders_out = num_calls;