radv/rt: Add and use radv_build_traversal

Moves most of the build code to a helper which will be useful for adding
inline traversal.

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24809>
Committed by: Marge Bot
Parent commit: 2d7965dbff
Commit: 774421f11e
@@ -1328,106 +1328,79 @@ load_stack_entry(nir_builder *b, nir_def *index, const struct radv_ray_traversal
|
||||
return nir_load_shared(b, 1, 32, index, .base = 0, .align_mul = 4);
|
||||
}
|
||||
|
||||
nir_shader *
|
||||
radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
|
||||
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, const struct radv_pipeline_key *key)
|
||||
static void
|
||||
radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
|
||||
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, const struct radv_pipeline_key *key,
|
||||
nir_builder *b, struct rt_variables *vars)
|
||||
{
|
||||
const VkPipelineCreateFlagBits2KHR create_flags = radv_get_pipeline_create_flags(pCreateInfo);
|
||||
|
||||
/* Create the traversal shader as an intersection shader to prevent validation failures due to
|
||||
* invalid variable modes.*/
|
||||
nir_builder b = radv_meta_init_shader(device, MESA_SHADER_INTERSECTION, "rt_traversal");
|
||||
b.shader->info.internal = false;
|
||||
b.shader->info.workgroup_size[0] = 8;
|
||||
b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
|
||||
b.shader->info.shared_size = device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
|
||||
struct rt_variables vars = create_rt_variables(b.shader, create_flags);
|
||||
|
||||
/* Register storage for hit attributes */
|
||||
nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_SIZE / sizeof(uint32_t)];
|
||||
|
||||
for (uint32_t i = 0; i < ARRAY_SIZE(hit_attribs); i++)
|
||||
hit_attribs[i] = nir_local_variable_create(nir_shader_get_entrypoint(b.shader), glsl_uint_type(), "ahit_attrib");
|
||||
|
||||
nir_variable *barycentrics =
|
||||
nir_variable_create(b.shader, nir_var_ray_hit_attrib, glsl_vector_type(GLSL_TYPE_FLOAT, 2), "barycentrics");
|
||||
nir_variable_create(b->shader, nir_var_ray_hit_attrib, glsl_vector_type(GLSL_TYPE_FLOAT, 2), "barycentrics");
|
||||
barycentrics->data.driver_location = 0;
|
||||
|
||||
/* initialize trace_ray arguments */
|
||||
nir_def *accel_struct = nir_load_accel_struct_amd(&b);
|
||||
nir_def *cull_mask_and_flags = nir_load_cull_mask_and_flags_amd(&b);
|
||||
nir_store_var(&b, vars.cull_mask_and_flags, cull_mask_and_flags, 0x1);
|
||||
nir_store_var(&b, vars.sbt_offset, nir_load_sbt_offset_amd(&b), 0x1);
|
||||
nir_store_var(&b, vars.sbt_stride, nir_load_sbt_stride_amd(&b), 0x1);
|
||||
nir_store_var(&b, vars.origin, nir_load_ray_world_origin(&b), 0x7);
|
||||
nir_store_var(&b, vars.tmin, nir_load_ray_t_min(&b), 0x1);
|
||||
nir_store_var(&b, vars.direction, nir_load_ray_world_direction(&b), 0x7);
|
||||
nir_store_var(&b, vars.tmax, nir_load_ray_t_max(&b), 0x1);
|
||||
nir_store_var(&b, vars.arg, nir_load_rt_arg_scratch_offset_amd(&b), 0x1);
|
||||
nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
|
||||
struct rt_traversal_vars trav_vars = init_traversal_vars(b);
|
||||
|
||||
struct rt_traversal_vars trav_vars = init_traversal_vars(&b);
|
||||
|
||||
nir_store_var(&b, trav_vars.hit, nir_imm_false(&b), 1);
|
||||
nir_store_var(b, trav_vars.hit, nir_imm_false(b), 1);
|
||||
|
||||
nir_def *accel_struct = nir_load_var(b, vars->accel_struct);
|
||||
nir_def *bvh_offset = nir_build_load_global(
|
||||
&b, 1, 32, nir_iadd_imm(&b, accel_struct, offsetof(struct radv_accel_struct_header, bvh_offset)),
|
||||
b, 1, 32, nir_iadd_imm(b, accel_struct, offsetof(struct radv_accel_struct_header, bvh_offset)),
|
||||
.access = ACCESS_NON_WRITEABLE);
|
||||
nir_def *root_bvh_base = nir_iadd(&b, accel_struct, nir_u2u64(&b, bvh_offset));
|
||||
root_bvh_base = build_addr_to_node(&b, root_bvh_base);
|
||||
nir_def *root_bvh_base = nir_iadd(b, accel_struct, nir_u2u64(b, bvh_offset));
|
||||
root_bvh_base = build_addr_to_node(b, root_bvh_base);
|
||||
|
||||
nir_store_var(&b, trav_vars.bvh_base, root_bvh_base, 1);
|
||||
nir_store_var(b, trav_vars.bvh_base, root_bvh_base, 1);
|
||||
|
||||
nir_def *vec3ones = nir_imm_vec3(&b, 1.0, 1.0, 1.0);
|
||||
nir_def *vec3ones = nir_imm_vec3(b, 1.0, 1.0, 1.0);
|
||||
|
||||
nir_store_var(&b, trav_vars.origin, nir_load_var(&b, vars.origin), 7);
|
||||
nir_store_var(&b, trav_vars.dir, nir_load_var(&b, vars.direction), 7);
|
||||
nir_store_var(&b, trav_vars.inv_dir, nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
|
||||
nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_imm_int(&b, 0), 1);
|
||||
nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1);
|
||||
nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
|
||||
nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
|
||||
nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
|
||||
nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_imm_int(b, 0), 1);
|
||||
nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
|
||||
|
||||
nir_store_var(&b, trav_vars.stack, nir_imul_imm(&b, nir_load_local_invocation_index(&b), sizeof(uint32_t)), 1);
|
||||
nir_store_var(&b, trav_vars.stack_low_watermark, nir_load_var(&b, trav_vars.stack), 1);
|
||||
nir_store_var(&b, trav_vars.current_node, nir_imm_int(&b, RADV_BVH_ROOT_NODE), 0x1);
|
||||
nir_store_var(&b, trav_vars.previous_node, nir_imm_int(&b, RADV_BVH_INVALID_NODE), 0x1);
|
||||
nir_store_var(&b, trav_vars.instance_top_node, nir_imm_int(&b, RADV_BVH_INVALID_NODE), 0x1);
|
||||
nir_store_var(&b, trav_vars.instance_bottom_node, nir_imm_int(&b, RADV_BVH_NO_INSTANCE_ROOT), 0x1);
|
||||
nir_store_var(b, trav_vars.stack, nir_imul_imm(b, nir_load_local_invocation_index(b), sizeof(uint32_t)), 1);
|
||||
nir_store_var(b, trav_vars.stack_low_watermark, nir_load_var(b, trav_vars.stack), 1);
|
||||
nir_store_var(b, trav_vars.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
|
||||
nir_store_var(b, trav_vars.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
|
||||
nir_store_var(b, trav_vars.instance_top_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
|
||||
nir_store_var(b, trav_vars.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 0x1);
|
||||
|
||||
nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, -1), 1);
|
||||
nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, -1), 1);
|
||||
|
||||
struct radv_ray_traversal_vars trav_vars_args = {
|
||||
.tmax = nir_build_deref_var(&b, vars.tmax),
|
||||
.origin = nir_build_deref_var(&b, trav_vars.origin),
|
||||
.dir = nir_build_deref_var(&b, trav_vars.dir),
|
||||
.inv_dir = nir_build_deref_var(&b, trav_vars.inv_dir),
|
||||
.bvh_base = nir_build_deref_var(&b, trav_vars.bvh_base),
|
||||
.stack = nir_build_deref_var(&b, trav_vars.stack),
|
||||
.top_stack = nir_build_deref_var(&b, trav_vars.top_stack),
|
||||
.stack_low_watermark = nir_build_deref_var(&b, trav_vars.stack_low_watermark),
|
||||
.current_node = nir_build_deref_var(&b, trav_vars.current_node),
|
||||
.previous_node = nir_build_deref_var(&b, trav_vars.previous_node),
|
||||
.instance_top_node = nir_build_deref_var(&b, trav_vars.instance_top_node),
|
||||
.instance_bottom_node = nir_build_deref_var(&b, trav_vars.instance_bottom_node),
|
||||
.instance_addr = nir_build_deref_var(&b, trav_vars.instance_addr),
|
||||
.sbt_offset_and_flags = nir_build_deref_var(&b, trav_vars.sbt_offset_and_flags),
|
||||
.tmax = nir_build_deref_var(b, vars->tmax),
|
||||
.origin = nir_build_deref_var(b, trav_vars.origin),
|
||||
.dir = nir_build_deref_var(b, trav_vars.dir),
|
||||
.inv_dir = nir_build_deref_var(b, trav_vars.inv_dir),
|
||||
.bvh_base = nir_build_deref_var(b, trav_vars.bvh_base),
|
||||
.stack = nir_build_deref_var(b, trav_vars.stack),
|
||||
.top_stack = nir_build_deref_var(b, trav_vars.top_stack),
|
||||
.stack_low_watermark = nir_build_deref_var(b, trav_vars.stack_low_watermark),
|
||||
.current_node = nir_build_deref_var(b, trav_vars.current_node),
|
||||
.previous_node = nir_build_deref_var(b, trav_vars.previous_node),
|
||||
.instance_top_node = nir_build_deref_var(b, trav_vars.instance_top_node),
|
||||
.instance_bottom_node = nir_build_deref_var(b, trav_vars.instance_bottom_node),
|
||||
.instance_addr = nir_build_deref_var(b, trav_vars.instance_addr),
|
||||
.sbt_offset_and_flags = nir_build_deref_var(b, trav_vars.sbt_offset_and_flags),
|
||||
};
|
||||
|
||||
struct traversal_data data = {
|
||||
.device = device,
|
||||
.vars = &vars,
|
||||
.vars = vars,
|
||||
.trav_vars = &trav_vars,
|
||||
.barycentrics = barycentrics,
|
||||
.pipeline = pipeline,
|
||||
.key = key,
|
||||
};
|
||||
|
||||
nir_def *cull_mask_and_flags = nir_load_var(b, vars->cull_mask_and_flags);
|
||||
struct radv_ray_traversal_args args = {
|
||||
.root_bvh_base = root_bvh_base,
|
||||
.flags = cull_mask_and_flags,
|
||||
.cull_mask = cull_mask_and_flags,
|
||||
.origin = nir_load_var(&b, vars.origin),
|
||||
.tmin = nir_load_var(&b, vars.tmin),
|
||||
.dir = nir_load_var(&b, vars.direction),
|
||||
.origin = nir_load_var(b, vars->origin),
|
||||
.tmin = nir_load_var(b, vars->tmin),
|
||||
.dir = nir_load_var(b, vars->direction),
|
||||
.vars = trav_vars_args,
|
||||
.stack_stride = device->physical_device->rt_wave_size * sizeof(uint32_t),
|
||||
.stack_entries = MAX_STACK_ENTRY_COUNT,
|
||||
@@ -1443,28 +1416,65 @@ radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_
|
||||
.data = &data,
|
||||
};
|
||||
|
||||
radv_build_ray_traversal(device, &b, &args);
|
||||
radv_build_ray_traversal(device, b, &args);
|
||||
|
||||
nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
|
||||
lower_hit_attrib_derefs(b.shader);
|
||||
lower_hit_attribs(b.shader, hit_attribs, device->physical_device->rt_wave_size);
|
||||
nir_metadata_preserve(nir_shader_get_entrypoint(b->shader), nir_metadata_none);
|
||||
lower_hit_attrib_derefs(b->shader);
|
||||
|
||||
/* Register storage for hit attributes */
|
||||
nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_SIZE / sizeof(uint32_t)];
|
||||
|
||||
for (uint32_t i = 0; i < ARRAY_SIZE(hit_attribs); i++)
|
||||
hit_attribs[i] = nir_local_variable_create(nir_shader_get_entrypoint(b->shader), glsl_uint_type(), "ahit_attrib");
|
||||
|
||||
lower_hit_attribs(b->shader, hit_attribs, device->physical_device->rt_wave_size);
|
||||
|
||||
/* Initialize follow-up shader. */
|
||||
nir_push_if(&b, nir_load_var(&b, trav_vars.hit));
|
||||
nir_push_if(b, nir_load_var(b, trav_vars.hit));
|
||||
{
|
||||
for (int i = 0; i < ARRAY_SIZE(hit_attribs); ++i)
|
||||
nir_store_hit_attrib_amd(&b, nir_load_var(&b, hit_attribs[i]), .base = i);
|
||||
nir_execute_closest_hit_amd(&b, nir_load_var(&b, vars.idx), nir_load_var(&b, vars.tmax),
|
||||
nir_load_var(&b, vars.primitive_id), nir_load_var(&b, vars.instance_addr),
|
||||
nir_load_var(&b, vars.geometry_id_and_flags), nir_load_var(&b, vars.hit_kind));
|
||||
nir_store_hit_attrib_amd(b, nir_load_var(b, hit_attribs[i]), .base = i);
|
||||
nir_execute_closest_hit_amd(b, nir_load_var(b, vars->idx), nir_load_var(b, vars->tmax),
|
||||
nir_load_var(b, vars->primitive_id), nir_load_var(b, vars->instance_addr),
|
||||
nir_load_var(b, vars->geometry_id_and_flags), nir_load_var(b, vars->hit_kind));
|
||||
}
|
||||
nir_push_else(&b, NULL);
|
||||
nir_push_else(b, NULL);
|
||||
{
|
||||
/* Only load the miss shader if we actually miss. It is valid to not specify an SBT pointer
|
||||
* for miss shaders if none of the rays miss. */
|
||||
nir_execute_miss_amd(&b, nir_load_var(&b, vars.tmax));
|
||||
nir_execute_miss_amd(b, nir_load_var(b, vars->tmax));
|
||||
}
|
||||
nir_pop_if(&b, NULL);
|
||||
nir_pop_if(b, NULL);
|
||||
}
|
||||
|
||||
nir_shader *
|
||||
radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
|
||||
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, const struct radv_pipeline_key *key)
|
||||
{
|
||||
const VkPipelineCreateFlagBits2KHR create_flags = radv_get_pipeline_create_flags(pCreateInfo);
|
||||
|
||||
/* Create the traversal shader as an intersection shader to prevent validation failures due to
|
||||
* invalid variable modes.*/
|
||||
nir_builder b = radv_meta_init_shader(device, MESA_SHADER_INTERSECTION, "rt_traversal");
|
||||
b.shader->info.internal = false;
|
||||
b.shader->info.workgroup_size[0] = 8;
|
||||
b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
|
||||
b.shader->info.shared_size = device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
|
||||
struct rt_variables vars = create_rt_variables(b.shader, create_flags);
|
||||
|
||||
/* initialize trace_ray arguments */
|
||||
nir_store_var(&b, vars.accel_struct, nir_load_accel_struct_amd(&b), 1);
|
||||
nir_store_var(&b, vars.cull_mask_and_flags, nir_load_cull_mask_and_flags_amd(&b), 0x1);
|
||||
nir_store_var(&b, vars.sbt_offset, nir_load_sbt_offset_amd(&b), 0x1);
|
||||
nir_store_var(&b, vars.sbt_stride, nir_load_sbt_stride_amd(&b), 0x1);
|
||||
nir_store_var(&b, vars.origin, nir_load_ray_world_origin(&b), 0x7);
|
||||
nir_store_var(&b, vars.tmin, nir_load_ray_t_min(&b), 0x1);
|
||||
nir_store_var(&b, vars.direction, nir_load_ray_world_direction(&b), 0x7);
|
||||
nir_store_var(&b, vars.tmax, nir_load_ray_t_max(&b), 0x1);
|
||||
nir_store_var(&b, vars.arg, nir_load_rt_arg_scratch_offset_amd(&b), 0x1);
|
||||
nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
|
||||
|
||||
radv_build_traversal(device, pipeline, pCreateInfo, key, &b, &vars);
|
||||
|
||||
/* Deal with all the inline functions. */
|
||||
nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
|
||||
|
Reference in New Issue
Block a user