intel/rt: Add support for scratch in ray-tracing shaders
In ray-tracing shader stages, we have a real call stack and so we can't use the normal scratch mechanism. Instead, the invocation's stack lives in a memory region of the RT scratch buffer that sits after the HW ray stacks. We handle this by asking nir_lower_io to lower local variables to 64-bit global memory access. Unlike nir_lower_io for 32-bit offset scratch, when 64-bit global access is requested, nir_lower_io generates an address calculation which starts from a load_scratch_base_ptr. We then lower this intrinsic to the appropriate address calculation in brw_nir_lower_rt_intrinsics. When a COMPUTE_WALKER command is sent to the hardware with the BTD Mode bit set to true, the hardware generates a set of stack IDs, one for each invocation. These then get passed along from one shader invocation to the next as we trace the ray. We can use those stack IDs to figure out which stack our invocation needs to access. Because we may not be the first shader in the stack, there's a per-stack offset that gets stored in the "hotzone". Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7356>
This commit is contained in:

committed by
Marge Bot

parent
2b3f6cdc6c
commit
49778a7253
@@ -37,6 +37,14 @@ lower_rt_intrinsics_impl(nir_function_impl *impl,
|
||||
struct brw_nir_rt_globals_defs globals;
|
||||
brw_nir_rt_load_globals(b, &globals);
|
||||
|
||||
nir_ssa_def *hotzone_addr = brw_nir_rt_sw_hotzone_addr(b, devinfo);
|
||||
nir_ssa_def *hotzone = nir_load_global(b, hotzone_addr, 16, 4, 32);
|
||||
|
||||
nir_ssa_def *thread_stack_base_addr = brw_nir_rt_sw_stack_addr(b, devinfo);
|
||||
nir_ssa_def *stack_base_offset = nir_channel(b, hotzone, 0);
|
||||
nir_ssa_def *stack_base_addr =
|
||||
nir_iadd(b, thread_stack_base_addr, nir_u2u64(b, stack_base_offset));
|
||||
|
||||
nir_foreach_block(block, impl) {
|
||||
nir_foreach_instr_safe(instr, block) {
|
||||
if (instr->type != nir_instr_type_intrinsic)
|
||||
@@ -48,6 +56,11 @@ lower_rt_intrinsics_impl(nir_function_impl *impl,
|
||||
|
||||
nir_ssa_def *sysval = NULL;
|
||||
switch (intrin->intrinsic) {
|
||||
case nir_intrinsic_load_scratch_base_ptr:
|
||||
assert(nir_intrinsic_base(intrin) == 1);
|
||||
sysval = stack_base_addr;
|
||||
break;
|
||||
|
||||
case nir_intrinsic_load_ray_base_mem_addr_intel:
|
||||
sysval = globals.base_mem_addr;
|
||||
break;
|
||||
|
@@ -22,35 +22,124 @@
|
||||
*/
|
||||
|
||||
#include "brw_nir_rt.h"
|
||||
#include "nir_builder.h"
|
||||
|
||||
/* Resize a deref instruction's SSA destination to the given
 * num_components/bit_size.
 *
 * Returns true if the deref was changed, false if it already had the
 * requested size.  For (ptr_as_)array derefs whose bit size changes, the
 * array index source is re-emitted at the new bit size first, because NIR
 * requires array indices to match the deref's bit size.
 */
static bool
resize_deref(nir_builder *b, nir_deref_instr *deref,
             unsigned num_components, unsigned bit_size)
{
   assert(deref->dest.is_ssa);
   /* Nothing to do if the destination already has the requested size. */
   if (deref->dest.ssa.num_components == num_components &&
       deref->dest.ssa.bit_size == bit_size)
      return false;

   /* NIR requires array indices have to match the deref bit size */
   if (deref->dest.ssa.bit_size != bit_size &&
       (deref->deref_type == nir_deref_type_array ||
        deref->deref_type == nir_deref_type_ptr_as_array)) {
      /* Emit the converted index immediately before the deref so it
       * dominates the use we are about to rewrite.
       */
      b->cursor = nir_before_instr(&deref->instr);
      assert(deref->arr.index.is_ssa);
      nir_ssa_def *idx;
      if (nir_src_is_const(deref->arr.index)) {
         /* Re-materialize constant indices directly at the new bit size. */
         idx = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index), bit_size);
      } else {
         idx = nir_i2i(b, deref->arr.index.ssa, bit_size);
      }
      nir_instr_rewrite_src(&deref->instr, &deref->arr.index,
                            nir_src_for_ssa(idx));
   }

   /* Mutate the destination in place; uses are type-agnostic SSA refs. */
   deref->dest.ssa.num_components = num_components;
   deref->dest.ssa.bit_size = bit_size;

   return true;
}
|
||||
|
||||
static bool
|
||||
resize_function_temp_derefs(nir_shader *shader)
|
||||
{
|
||||
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
|
||||
|
||||
bool progress = false;
|
||||
|
||||
nir_builder b;
|
||||
nir_builder_init(&b, impl);
|
||||
|
||||
nir_foreach_block(block, impl) {
|
||||
nir_foreach_instr_safe(instr, block) {
|
||||
if (instr->type != nir_instr_type_deref)
|
||||
continue;
|
||||
|
||||
nir_deref_instr *deref = nir_instr_as_deref(instr);
|
||||
|
||||
/* We're going to lower all function_temp memory to scratch using
|
||||
* 64-bit addresses. We need to resize all our derefs first or else
|
||||
* nir_lower_explicit_io will have a fit.
|
||||
*/
|
||||
if (nir_deref_mode_is(deref, nir_var_function_temp) &&
|
||||
resize_deref(&b, deref, 1, 64))
|
||||
progress = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (progress) {
|
||||
nir_metadata_preserve(impl, nir_metadata_block_index |
|
||||
nir_metadata_dominance);
|
||||
} else {
|
||||
nir_metadata_preserve(impl, nir_metadata_all);
|
||||
}
|
||||
|
||||
return progress;
|
||||
}
|
||||
|
||||
/* Lower all function_temp (local variable) access in a ray-tracing shader
 * stage to 64-bit global memory access.  RT stages have a real call stack
 * living in the RT scratch buffer rather than HW scratch, so locals are
 * addressed relative to load_scratch_base_ptr, which is lowered later in
 * brw_nir_lower_rt_intrinsics.
 */
static void
lower_rt_scratch(nir_shader *nir)
{
   /* First, we need to ensure all the local variables have explicit types. */
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_function_temp,
              glsl_get_natural_size_align_bytes);

   /* Make every function_temp deref 1x64 so explicit-IO lowering with
    * 64-bit addresses doesn't trip over mismatched deref sizes.
    */
   NIR_PASS_V(nir, resize_function_temp_derefs);

   /* Now, lower those variables to 64-bit global memory access */
   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_function_temp,
              nir_address_format_64bit_global);
}
|
||||
|
||||
/* Stage-specific RT lowering for raygen shaders: lower locals to the
 * software stack in the RT scratch buffer.
 */
void
brw_nir_lower_raygen(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_RAYGEN);
   lower_rt_scratch(nir);
}
|
||||
|
||||
/* Stage-specific RT lowering for any-hit shaders: lower locals to the
 * software stack in the RT scratch buffer.
 *
 * NOTE(review): devinfo is currently unused here; presumably reserved for
 * later any-hit-specific lowering — confirm against callers.
 */
void
brw_nir_lower_any_hit(nir_shader *nir, const struct gen_device_info *devinfo)
{
   assert(nir->info.stage == MESA_SHADER_ANY_HIT);
   lower_rt_scratch(nir);
}
|
||||
|
||||
/* Stage-specific RT lowering for closest-hit shaders: lower locals to the
 * software stack in the RT scratch buffer.
 */
void
brw_nir_lower_closest_hit(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_CLOSEST_HIT);
   lower_rt_scratch(nir);
}
|
||||
|
||||
/* Stage-specific RT lowering for miss shaders: lower locals to the
 * software stack in the RT scratch buffer.
 */
void
brw_nir_lower_miss(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_MISS);
   lower_rt_scratch(nir);
}
|
||||
|
||||
/* Stage-specific RT lowering for callable shaders: lower locals to the
 * software stack in the RT scratch buffer.
 */
void
brw_nir_lower_callable(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_CALLABLE);
   lower_rt_scratch(nir);
}
|
||||
|
||||
void
|
||||
@@ -60,4 +149,5 @@ brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
|
||||
{
|
||||
assert(intersection->info.stage == MESA_SHADER_INTERSECTION);
|
||||
assert(any_hit == NULL || any_hit->info.stage == MESA_SHADER_ANY_HIT);
|
||||
lower_rt_scratch(intersection);
|
||||
}
|
||||
|
@@ -41,6 +41,9 @@ void brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
|
||||
const nir_shader *any_hit,
|
||||
const struct gen_device_info *devinfo);
|
||||
|
||||
/* We require the stack to be 8B aligned at the start of a shader */
|
||||
#define BRW_BTD_STACK_ALIGN 8
|
||||
|
||||
void brw_nir_lower_rt_intrinsics(nir_shader *shader,
|
||||
const struct gen_device_info *devinfo);
|
||||
|
||||
|
@@ -41,6 +41,30 @@ nir_load_global_const_block_intel(nir_builder *b, nir_ssa_def *addr,
|
||||
return &load->dest.ssa;
|
||||
}
|
||||
|
||||
/* We have our own load/store scratch helpers because they emit a global
|
||||
* memory read or write based on the scratch_base_ptr system value rather
|
||||
* than a load/store_scratch intrinsic.
|
||||
*/
|
||||
static inline nir_ssa_def *
|
||||
brw_nir_rt_load_scratch(nir_builder *b, uint32_t offset, unsigned align,
|
||||
unsigned num_components, unsigned bit_size)
|
||||
{
|
||||
nir_ssa_def *addr =
|
||||
nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 1, 64), offset);
|
||||
return nir_load_global(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
|
||||
num_components, bit_size);
|
||||
}
|
||||
|
||||
static inline void
|
||||
brw_nir_rt_store_scratch(nir_builder *b, uint32_t offset, unsigned align,
|
||||
nir_ssa_def *value, nir_component_mask_t write_mask)
|
||||
{
|
||||
nir_ssa_def *addr =
|
||||
nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 1, 64), offset);
|
||||
nir_store_global(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
|
||||
value, write_mask);
|
||||
}
|
||||
|
||||
static inline void
|
||||
assert_def_size(nir_ssa_def *def, unsigned num_components, unsigned bit_size)
|
||||
{
|
||||
|
Reference in New Issue
Block a user