radv/rt: Add and use radv_build_traversal

Moves most of the build code to a helper which will be useful for adding
inline traversal.

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24809>
Author: Konstantin Seurer
Date: 2023-06-24 16:04:52 +02:00
Committed-by: Marge Bot
Parent commit: 2d7965dbff
Commit: 774421f11e

View File

@@ -1328,106 +1328,79 @@ load_stack_entry(nir_builder *b, nir_def *index, const struct radv_ray_traversal
return nir_load_shared(b, 1, 32, index, .base = 0, .align_mul = 4);
}
nir_shader *
radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, const struct radv_pipeline_key *key)
static void
radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, const struct radv_pipeline_key *key,
nir_builder *b, struct rt_variables *vars)
{
const VkPipelineCreateFlagBits2KHR create_flags = radv_get_pipeline_create_flags(pCreateInfo);
/* Create the traversal shader as an intersection shader to prevent validation failures due to
* invalid variable modes.*/
nir_builder b = radv_meta_init_shader(device, MESA_SHADER_INTERSECTION, "rt_traversal");
b.shader->info.internal = false;
b.shader->info.workgroup_size[0] = 8;
b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
b.shader->info.shared_size = device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
struct rt_variables vars = create_rt_variables(b.shader, create_flags);
/* Register storage for hit attributes */
nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_SIZE / sizeof(uint32_t)];
for (uint32_t i = 0; i < ARRAY_SIZE(hit_attribs); i++)
hit_attribs[i] = nir_local_variable_create(nir_shader_get_entrypoint(b.shader), glsl_uint_type(), "ahit_attrib");
nir_variable *barycentrics =
nir_variable_create(b.shader, nir_var_ray_hit_attrib, glsl_vector_type(GLSL_TYPE_FLOAT, 2), "barycentrics");
nir_variable_create(b->shader, nir_var_ray_hit_attrib, glsl_vector_type(GLSL_TYPE_FLOAT, 2), "barycentrics");
barycentrics->data.driver_location = 0;
/* initialize trace_ray arguments */
nir_def *accel_struct = nir_load_accel_struct_amd(&b);
nir_def *cull_mask_and_flags = nir_load_cull_mask_and_flags_amd(&b);
nir_store_var(&b, vars.cull_mask_and_flags, cull_mask_and_flags, 0x1);
nir_store_var(&b, vars.sbt_offset, nir_load_sbt_offset_amd(&b), 0x1);
nir_store_var(&b, vars.sbt_stride, nir_load_sbt_stride_amd(&b), 0x1);
nir_store_var(&b, vars.origin, nir_load_ray_world_origin(&b), 0x7);
nir_store_var(&b, vars.tmin, nir_load_ray_t_min(&b), 0x1);
nir_store_var(&b, vars.direction, nir_load_ray_world_direction(&b), 0x7);
nir_store_var(&b, vars.tmax, nir_load_ray_t_max(&b), 0x1);
nir_store_var(&b, vars.arg, nir_load_rt_arg_scratch_offset_amd(&b), 0x1);
nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
struct rt_traversal_vars trav_vars = init_traversal_vars(b);
struct rt_traversal_vars trav_vars = init_traversal_vars(&b);
nir_store_var(&b, trav_vars.hit, nir_imm_false(&b), 1);
nir_store_var(b, trav_vars.hit, nir_imm_false(b), 1);
nir_def *accel_struct = nir_load_var(b, vars->accel_struct);
nir_def *bvh_offset = nir_build_load_global(
&b, 1, 32, nir_iadd_imm(&b, accel_struct, offsetof(struct radv_accel_struct_header, bvh_offset)),
b, 1, 32, nir_iadd_imm(b, accel_struct, offsetof(struct radv_accel_struct_header, bvh_offset)),
.access = ACCESS_NON_WRITEABLE);
nir_def *root_bvh_base = nir_iadd(&b, accel_struct, nir_u2u64(&b, bvh_offset));
root_bvh_base = build_addr_to_node(&b, root_bvh_base);
nir_def *root_bvh_base = nir_iadd(b, accel_struct, nir_u2u64(b, bvh_offset));
root_bvh_base = build_addr_to_node(b, root_bvh_base);
nir_store_var(&b, trav_vars.bvh_base, root_bvh_base, 1);
nir_store_var(b, trav_vars.bvh_base, root_bvh_base, 1);
nir_def *vec3ones = nir_imm_vec3(&b, 1.0, 1.0, 1.0);
nir_def *vec3ones = nir_imm_vec3(b, 1.0, 1.0, 1.0);
nir_store_var(&b, trav_vars.origin, nir_load_var(&b, vars.origin), 7);
nir_store_var(&b, trav_vars.dir, nir_load_var(&b, vars.direction), 7);
nir_store_var(&b, trav_vars.inv_dir, nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_imm_int(&b, 0), 1);
nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1);
nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_imm_int(b, 0), 1);
nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
nir_store_var(&b, trav_vars.stack, nir_imul_imm(&b, nir_load_local_invocation_index(&b), sizeof(uint32_t)), 1);
nir_store_var(&b, trav_vars.stack_low_watermark, nir_load_var(&b, trav_vars.stack), 1);
nir_store_var(&b, trav_vars.current_node, nir_imm_int(&b, RADV_BVH_ROOT_NODE), 0x1);
nir_store_var(&b, trav_vars.previous_node, nir_imm_int(&b, RADV_BVH_INVALID_NODE), 0x1);
nir_store_var(&b, trav_vars.instance_top_node, nir_imm_int(&b, RADV_BVH_INVALID_NODE), 0x1);
nir_store_var(&b, trav_vars.instance_bottom_node, nir_imm_int(&b, RADV_BVH_NO_INSTANCE_ROOT), 0x1);
nir_store_var(b, trav_vars.stack, nir_imul_imm(b, nir_load_local_invocation_index(b), sizeof(uint32_t)), 1);
nir_store_var(b, trav_vars.stack_low_watermark, nir_load_var(b, trav_vars.stack), 1);
nir_store_var(b, trav_vars.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
nir_store_var(b, trav_vars.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
nir_store_var(b, trav_vars.instance_top_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
nir_store_var(b, trav_vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 0x1);
nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, -1), 1);
nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, -1), 1);
struct radv_ray_traversal_vars trav_vars_args = {
.tmax = nir_build_deref_var(&b, vars.tmax),
.origin = nir_build_deref_var(&b, trav_vars.origin),
.dir = nir_build_deref_var(&b, trav_vars.dir),
.inv_dir = nir_build_deref_var(&b, trav_vars.inv_dir),
.bvh_base = nir_build_deref_var(&b, trav_vars.bvh_base),
.stack = nir_build_deref_var(&b, trav_vars.stack),
.top_stack = nir_build_deref_var(&b, trav_vars.top_stack),
.stack_low_watermark = nir_build_deref_var(&b, trav_vars.stack_low_watermark),
.current_node = nir_build_deref_var(&b, trav_vars.current_node),
.previous_node = nir_build_deref_var(&b, trav_vars.previous_node),
.instance_top_node = nir_build_deref_var(&b, trav_vars.instance_top_node),
.instance_bottom_node = nir_build_deref_var(&b, trav_vars.instance_bottom_node),
.instance_addr = nir_build_deref_var(&b, trav_vars.instance_addr),
.sbt_offset_and_flags = nir_build_deref_var(&b, trav_vars.sbt_offset_and_flags),
.tmax = nir_build_deref_var(b, vars->tmax),
.origin = nir_build_deref_var(b, trav_vars.origin),
.dir = nir_build_deref_var(b, trav_vars.dir),
.inv_dir = nir_build_deref_var(b, trav_vars.inv_dir),
.bvh_base = nir_build_deref_var(b, trav_vars.bvh_base),
.stack = nir_build_deref_var(b, trav_vars.stack),
.top_stack = nir_build_deref_var(b, trav_vars.top_stack),
.stack_low_watermark = nir_build_deref_var(b, trav_vars.stack_low_watermark),
.current_node = nir_build_deref_var(b, trav_vars.current_node),
.previous_node = nir_build_deref_var(b, trav_vars.previous_node),
.instance_top_node = nir_build_deref_var(b, trav_vars.instance_top_node),
.instance_bottom_node = nir_build_deref_var(b, trav_vars.instance_bottom_node),
.instance_addr = nir_build_deref_var(b, trav_vars.instance_addr),
.sbt_offset_and_flags = nir_build_deref_var(b, trav_vars.sbt_offset_and_flags),
};
struct traversal_data data = {
.device = device,
.vars = &vars,
.vars = vars,
.trav_vars = &trav_vars,
.barycentrics = barycentrics,
.pipeline = pipeline,
.key = key,
};
nir_def *cull_mask_and_flags = nir_load_var(b, vars->cull_mask_and_flags);
struct radv_ray_traversal_args args = {
.root_bvh_base = root_bvh_base,
.flags = cull_mask_and_flags,
.cull_mask = cull_mask_and_flags,
.origin = nir_load_var(&b, vars.origin),
.tmin = nir_load_var(&b, vars.tmin),
.dir = nir_load_var(&b, vars.direction),
.origin = nir_load_var(b, vars->origin),
.tmin = nir_load_var(b, vars->tmin),
.dir = nir_load_var(b, vars->direction),
.vars = trav_vars_args,
.stack_stride = device->physical_device->rt_wave_size * sizeof(uint32_t),
.stack_entries = MAX_STACK_ENTRY_COUNT,
@@ -1443,28 +1416,65 @@ radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_
.data = &data,
};
radv_build_ray_traversal(device, &b, &args);
radv_build_ray_traversal(device, b, &args);
nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
lower_hit_attrib_derefs(b.shader);
lower_hit_attribs(b.shader, hit_attribs, device->physical_device->rt_wave_size);
nir_metadata_preserve(nir_shader_get_entrypoint(b->shader), nir_metadata_none);
lower_hit_attrib_derefs(b->shader);
/* Register storage for hit attributes */
nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_SIZE / sizeof(uint32_t)];
for (uint32_t i = 0; i < ARRAY_SIZE(hit_attribs); i++)
hit_attribs[i] = nir_local_variable_create(nir_shader_get_entrypoint(b->shader), glsl_uint_type(), "ahit_attrib");
lower_hit_attribs(b->shader, hit_attribs, device->physical_device->rt_wave_size);
/* Initialize follow-up shader. */
nir_push_if(&b, nir_load_var(&b, trav_vars.hit));
nir_push_if(b, nir_load_var(b, trav_vars.hit));
{
for (int i = 0; i < ARRAY_SIZE(hit_attribs); ++i)
nir_store_hit_attrib_amd(&b, nir_load_var(&b, hit_attribs[i]), .base = i);
nir_execute_closest_hit_amd(&b, nir_load_var(&b, vars.idx), nir_load_var(&b, vars.tmax),
nir_load_var(&b, vars.primitive_id), nir_load_var(&b, vars.instance_addr),
nir_load_var(&b, vars.geometry_id_and_flags), nir_load_var(&b, vars.hit_kind));
nir_store_hit_attrib_amd(b, nir_load_var(b, hit_attribs[i]), .base = i);
nir_execute_closest_hit_amd(b, nir_load_var(b, vars->idx), nir_load_var(b, vars->tmax),
nir_load_var(b, vars->primitive_id), nir_load_var(b, vars->instance_addr),
nir_load_var(b, vars->geometry_id_and_flags), nir_load_var(b, vars->hit_kind));
}
nir_push_else(&b, NULL);
nir_push_else(b, NULL);
{
/* Only load the miss shader if we actually miss. It is valid to not specify an SBT pointer
* for miss shaders if none of the rays miss. */
nir_execute_miss_amd(&b, nir_load_var(&b, vars.tmax));
nir_execute_miss_amd(b, nir_load_var(b, vars->tmax));
}
nir_pop_if(&b, NULL);
nir_pop_if(b, NULL);
}
nir_shader *
radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, const struct radv_pipeline_key *key)
{
const VkPipelineCreateFlagBits2KHR create_flags = radv_get_pipeline_create_flags(pCreateInfo);
/* Create the traversal shader as an intersection shader to prevent validation failures due to
* invalid variable modes.*/
nir_builder b = radv_meta_init_shader(device, MESA_SHADER_INTERSECTION, "rt_traversal");
b.shader->info.internal = false;
b.shader->info.workgroup_size[0] = 8;
b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
b.shader->info.shared_size = device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
struct rt_variables vars = create_rt_variables(b.shader, create_flags);
/* initialize trace_ray arguments */
nir_store_var(&b, vars.accel_struct, nir_load_accel_struct_amd(&b), 1);
nir_store_var(&b, vars.cull_mask_and_flags, nir_load_cull_mask_and_flags_amd(&b), 0x1);
nir_store_var(&b, vars.sbt_offset, nir_load_sbt_offset_amd(&b), 0x1);
nir_store_var(&b, vars.sbt_stride, nir_load_sbt_stride_amd(&b), 0x1);
nir_store_var(&b, vars.origin, nir_load_ray_world_origin(&b), 0x7);
nir_store_var(&b, vars.tmin, nir_load_ray_t_min(&b), 0x1);
nir_store_var(&b, vars.direction, nir_load_ray_world_direction(&b), 0x7);
nir_store_var(&b, vars.tmax, nir_load_ray_t_max(&b), 0x1);
nir_store_var(&b, vars.arg, nir_load_rt_arg_scratch_offset_amd(&b), 0x1);
nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
radv_build_traversal(device, pipeline, pCreateInfo, key, &b, &vars);
/* Deal with all the inline functions. */
nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));