intel/rt: Implement traceRay()

This is a little bit more work than executeCallable() because we also
have to set up the MemRay data structure which the ray traversal
hardware uses to keep its state.

Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7356>
This commit is contained in:
Jason Ekstrand
2020-08-06 15:51:58 -05:00
committed by Marge Bot
parent 75209d5bd1
commit 7ce7c93755
3 changed files with 105 additions and 2 deletions

View File

@@ -106,6 +106,22 @@ lower_rt_intrinsics_impl(nir_function_impl *impl,
sysval = globals.num_dss_rt_stacks;
break;
case nir_intrinsic_load_ray_hit_sbt_addr_intel:
sysval = globals.hit_sbt_addr;
break;
case nir_intrinsic_load_ray_hit_sbt_stride_intel:
sysval = globals.hit_sbt_stride;
break;
case nir_intrinsic_load_ray_miss_sbt_addr_intel:
sysval = globals.miss_sbt_addr;
break;
case nir_intrinsic_load_ray_miss_sbt_stride_intel:
sysval = globals.miss_sbt_stride;
break;
case nir_intrinsic_load_callable_sbt_addr_intel:
sysval = globals.call_sbt_addr;
break;

View File

@@ -245,6 +245,10 @@ can_remat_instr(nir_instr *instr, struct bitset *remat)
case nir_intrinsic_load_ray_hw_stack_size_intel:
case nir_intrinsic_load_ray_sw_stack_size_intel:
case nir_intrinsic_load_ray_num_dss_rt_stacks_intel:
case nir_intrinsic_load_ray_hit_sbt_addr_intel:
case nir_intrinsic_load_ray_hit_sbt_stride_intel:
case nir_intrinsic_load_ray_miss_sbt_addr_intel:
case nir_intrinsic_load_ray_miss_sbt_stride_intel:
case nir_intrinsic_load_callable_sbt_addr_intel:
case nir_intrinsic_load_callable_sbt_stride_intel:
/* Notably missing from the above list is btd_local_arg_addr_intel.
@@ -529,8 +533,87 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls,
/* Lower to the _intel intrinsic */
switch (call->intrinsic) {
case nir_intrinsic_trace_ray:
unreachable("TODO");
case nir_intrinsic_trace_ray: {
nir_ssa_def *as_addr = call->src[0].ssa;
nir_ssa_def *ray_flags = call->src[1].ssa;
/* From the SPIR-V spec:
*
* "Only the 8 least-significant bits of Cull Mask are used by
* this instruction - other bits are ignored.
*
* Only the 4 least-significant bits of SBT Offset and SBT
* Stride are used by this instruction - other bits are
* ignored.
*
* Only the 16 least-significant bits of Miss Index are used by
* this instruction - other bits are ignored."
*/
nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff);
nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf);
nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf);
nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff);
nir_ssa_def *ray_orig = call->src[6].ssa;
nir_ssa_def *ray_t_min = call->src[7].ssa;
nir_ssa_def *ray_dir = call->src[8].ssa;
nir_ssa_def *ray_t_max = call->src[9].ssa;
/* The hardware packet takes the address to the root node in the
* acceleration structure, not the acceleration structure itself.
* To find that, we have to read the root node offset from the
* acceleration structure which is the first QWord.
*/
nir_ssa_def *root_node_ptr =
nir_iadd(b, as_addr, nir_load_global(b, as_addr, 256, 1, 64));
/* The hardware packet requires an address to the first element of
* the hit SBT.
*
* In order to calculate this, we must multiply the "SBT Offset"
* provided to OpTraceRay by the SBT stride provided for the hit
* SBT in the call to vkCmdTraceRay() and add that to the base
* address of the hit SBT. This stride is not to be confused with
* the "SBT Stride" provided to OpTraceRay which is in units of
* this stride. It's a rather terrible overload of the word
* "stride". The hardware docs calls the SPIR-V stride value the
* "shader index multiplier" which is a much more sane name.
*/
nir_ssa_def *hit_sbt_stride_B =
nir_load_ray_hit_sbt_stride_intel(b);
nir_ssa_def *hit_sbt_offset_B =
nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B));
nir_ssa_def *hit_sbt_addr =
nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b),
nir_u2u64(b, hit_sbt_offset_B));
/* The hardware packet takes an address to the miss BSR. */
nir_ssa_def *miss_sbt_stride_B =
nir_load_ray_miss_sbt_stride_intel(b);
nir_ssa_def *miss_sbt_offset_B =
nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B));
nir_ssa_def *miss_sbt_addr =
nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b),
nir_u2u64(b, miss_sbt_offset_B));
struct brw_nir_rt_mem_ray_defs ray_defs = {
.root_node_ptr = root_node_ptr,
.ray_flags = nir_u2u16(b, ray_flags),
.ray_mask = cull_mask,
.hit_group_sr_base_ptr = hit_sbt_addr,
.hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
.miss_sr_ptr = miss_sbt_addr,
.orig = ray_orig,
.t_near = ray_t_min,
.dir = ray_dir,
.t_far = ray_t_max,
.shader_index_multiplier = sbt_stride,
};
brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
nir_intrinsic_instr *ray_intel =
nir_intrinsic_instr_create(b->shader,
nir_intrinsic_trace_ray_initial_intel);
nir_builder_instr_insert(b, &ray_intel->instr);
break;
}
case nir_intrinsic_report_ray_intersection:
unreachable("Any-hit shaders must be inlined");