intel/rt: Add support for scratch in ray-tracing shaders

In ray-tracing shader stages, we have a real call stack, so we can't
use the normal scratch mechanism.  Instead, the invocation's stack lives
in a memory region of the RT scratch buffer that sits after the HW ray
stacks.  We handle this by asking nir_lower_explicit_io to lower local
variables to 64-bit global memory access.  Unlike the 32-bit offset
format used for regular scratch, when 64-bit global access is requested,
nir_lower_explicit_io generates an address calculation which starts from
a load_scratch_base_ptr intrinsic.  We then lower that intrinsic to the
appropriate address calculation in brw_nir_lower_rt_intrinsics.
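
For illustration, here is a minimal sketch (not the pass's exact
output; the offset, alignment, and helper name are hypothetical) of
what a lowered function_temp load looks like in nir_builder terms:

   /* Hypothetical helper: a 32-bit scalar scratch load lowered to a
    * 64-bit global access.  The base address comes from the
    * load_scratch_base_ptr intrinsic, which this commit resolves in
    * brw_nir_lower_rt_intrinsics.
    */
   static nir_ssa_def *
   example_lowered_scratch_load(nir_builder *b, uint32_t byte_offset)
   {
      nir_ssa_def *base = nir_load_scratch_base_ptr(b, 1, 1, 64);
      nir_ssa_def *addr = nir_iadd_imm(b, base, byte_offset);
      return nir_load_global(b, addr, 4 /* align */, 1, 32);
   }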

When a COMPUTE_WALKER command is sent to the hardware with the BTD Mode
bit set to true, the hardware generates a set of stack IDs, one for each
invocation.  These then get passed along from one shader invocation to
the next as we trace the ray.  We can use those stack IDs to figure out
which stack our invocation needs to access.  Because we may not be the
first shader in the stack, there's a per-stack offset that gets stored
in the "hotzone".

Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7356>
commit 49778a7253
parent 2b3f6cdc6c
Author:     Jason Ekstrand
AuthorDate: 2020-08-06 13:16:53 -05:00
Commit:     Marge Bot

4 files changed, 130 insertions(+), 0 deletions(-)

src/intel/compiler/brw_nir_lower_rt_intrinsics.c

@@ -37,6 +37,14 @@ lower_rt_intrinsics_impl(nir_function_impl *impl,
    struct brw_nir_rt_globals_defs globals;
    brw_nir_rt_load_globals(b, &globals);
 
+   nir_ssa_def *hotzone_addr = brw_nir_rt_sw_hotzone_addr(b, devinfo);
+   nir_ssa_def *hotzone = nir_load_global(b, hotzone_addr, 16, 4, 32);
+
+   nir_ssa_def *thread_stack_base_addr = brw_nir_rt_sw_stack_addr(b, devinfo);
+   nir_ssa_def *stack_base_offset = nir_channel(b, hotzone, 0);
+   nir_ssa_def *stack_base_addr =
+      nir_iadd(b, thread_stack_base_addr, nir_u2u64(b, stack_base_offset));
+
    nir_foreach_block(block, impl) {
       nir_foreach_instr_safe(instr, block) {
          if (instr->type != nir_instr_type_intrinsic)
@@ -48,6 +56,11 @@ lower_rt_intrinsics_impl(nir_function_impl *impl,
 
          nir_ssa_def *sysval = NULL;
          switch (intrin->intrinsic) {
+         case nir_intrinsic_load_scratch_base_ptr:
+            assert(nir_intrinsic_base(intrin) == 1);
+            sysval = stack_base_addr;
+            break;
+
          case nir_intrinsic_load_ray_base_mem_addr_intel:
            sysval = globals.base_mem_addr;
            break;

src/intel/compiler/brw_nir_rt.c

@@ -22,35 +22,124 @@
  */
 
 #include "brw_nir_rt.h"
+#include "nir_builder.h"
+
+static bool
+resize_deref(nir_builder *b, nir_deref_instr *deref,
+             unsigned num_components, unsigned bit_size)
+{
+   assert(deref->dest.is_ssa);
+   if (deref->dest.ssa.num_components == num_components &&
+       deref->dest.ssa.bit_size == bit_size)
+      return false;
+
+   /* NIR requires array indices to match the deref bit size */
+   if (deref->dest.ssa.bit_size != bit_size &&
+       (deref->deref_type == nir_deref_type_array ||
+        deref->deref_type == nir_deref_type_ptr_as_array)) {
+      b->cursor = nir_before_instr(&deref->instr);
+      assert(deref->arr.index.is_ssa);
+      nir_ssa_def *idx;
+      if (nir_src_is_const(deref->arr.index)) {
+         idx = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index), bit_size);
+      } else {
+         idx = nir_i2i(b, deref->arr.index.ssa, bit_size);
+      }
+      nir_instr_rewrite_src(&deref->instr, &deref->arr.index,
+                            nir_src_for_ssa(idx));
+   }
+
+   deref->dest.ssa.num_components = num_components;
+   deref->dest.ssa.bit_size = bit_size;
+
+   return true;
+}
+
+static bool
+resize_function_temp_derefs(nir_shader *shader)
+{
+   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
+
+   bool progress = false;
+   nir_builder b;
+   nir_builder_init(&b, impl);
+
+   nir_foreach_block(block, impl) {
+      nir_foreach_instr_safe(instr, block) {
+         if (instr->type != nir_instr_type_deref)
+            continue;
+
+         nir_deref_instr *deref = nir_instr_as_deref(instr);
+
+         /* We're going to lower all function_temp memory to scratch using
+          * 64-bit addresses.  We need to resize all our derefs first or else
+          * nir_lower_explicit_io will have a fit.
+          */
+         if (nir_deref_mode_is(deref, nir_var_function_temp) &&
+             resize_deref(&b, deref, 1, 64))
+            progress = true;
+      }
+   }
+
+   if (progress) {
+      nir_metadata_preserve(impl, nir_metadata_block_index |
+                                  nir_metadata_dominance);
+   } else {
+      nir_metadata_preserve(impl, nir_metadata_all);
+   }
+
+   return progress;
+}
+
+static void
+lower_rt_scratch(nir_shader *nir)
+{
+   /* First, we need to ensure all the local variables have explicit
+    * types.
+    */
+   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
+              nir_var_function_temp,
+              glsl_get_natural_size_align_bytes);
+
+   NIR_PASS_V(nir, resize_function_temp_derefs);
+
+   /* Now, lower those variables to 64-bit global memory access */
+   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_function_temp,
+              nir_address_format_64bit_global);
+}
 
 void
 brw_nir_lower_raygen(nir_shader *nir)
 {
    assert(nir->info.stage == MESA_SHADER_RAYGEN);
+   lower_rt_scratch(nir);
 }
 
 void
 brw_nir_lower_any_hit(nir_shader *nir, const struct gen_device_info *devinfo)
 {
    assert(nir->info.stage == MESA_SHADER_ANY_HIT);
+   lower_rt_scratch(nir);
 }
 
 void
 brw_nir_lower_closest_hit(nir_shader *nir)
 {
    assert(nir->info.stage == MESA_SHADER_CLOSEST_HIT);
+   lower_rt_scratch(nir);
 }
 
 void
 brw_nir_lower_miss(nir_shader *nir)
 {
    assert(nir->info.stage == MESA_SHADER_MISS);
+   lower_rt_scratch(nir);
 }
 
 void
 brw_nir_lower_callable(nir_shader *nir)
 {
    assert(nir->info.stage == MESA_SHADER_CALLABLE);
+   lower_rt_scratch(nir);
 }
 
 void
@@ -60,4 +149,5 @@ brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
 {
    assert(intersection->info.stage == MESA_SHADER_INTERSECTION);
    assert(any_hit == NULL || any_hit->info.stage == MESA_SHADER_ANY_HIT);
+   lower_rt_scratch(intersection);
 }

src/intel/compiler/brw_nir_rt.h

@@ -41,6 +41,9 @@ void brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
                                                  const nir_shader *any_hit,
                                                  const struct gen_device_info *devinfo);
 
+/* We require the stack to be 8B aligned at the start of a shader */
+#define BRW_BTD_STACK_ALIGN 8
+
 void brw_nir_lower_rt_intrinsics(nir_shader *shader,
                                  const struct gen_device_info *devinfo);

src/intel/compiler/brw_nir_rt_builder.h

@@ -41,6 +41,30 @@ nir_load_global_const_block_intel(nir_builder *b, nir_ssa_def *addr,
    return &load->dest.ssa;
 }
 
+/* We have our own load/store scratch helpers because they emit a global
+ * memory read or write based on the scratch_base_ptr system value rather
+ * than a load/store_scratch intrinsic.
+ */
+static inline nir_ssa_def *
+brw_nir_rt_load_scratch(nir_builder *b, uint32_t offset, unsigned align,
+                        unsigned num_components, unsigned bit_size)
+{
+   nir_ssa_def *addr =
+      nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 1, 64), offset);
+   return nir_load_global(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
+                          num_components, bit_size);
+}
+
+static inline void
+brw_nir_rt_store_scratch(nir_builder *b, uint32_t offset, unsigned align,
+                         nir_ssa_def *value, nir_component_mask_t write_mask)
+{
+   nir_ssa_def *addr =
+      nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 1, 64), offset);
+   nir_store_global(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
+                    value, write_mask);
+}
+
 static inline void
 assert_def_size(nir_ssa_def *def, unsigned num_components, unsigned bit_size)
 {