diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c
index a128bd92cfb..0f9a8dba576 100644
--- a/src/amd/vulkan/radv_pipeline_rt.c
+++ b/src/amd/vulkan/radv_pipeline_rt.c
@@ -1051,6 +1051,7 @@ struct rt_traversal_vars {
    nir_variable *hit;
    nir_variable *bvh_base;
    nir_variable *stack;
+   nir_variable *lds_stack_base;
    nir_variable *top_stack;
    nir_variable *current_node;
 };
@@ -1078,6 +1079,8 @@ init_traversal_vars(nir_builder *b)
                                       "traversal_bvh_base");
    ret.stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
                                    "traversal_stack_ptr");
+   ret.lds_stack_base = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
+                                            "traversal_lds_stack_base");
    ret.top_stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
                                        "traversal_top_stack_ptr");
    ret.current_node =
@@ -1406,7 +1409,8 @@ build_traversal_shader(struct radv_device *device,
                    b.shader->info.shared_size);
    nir_ssa_def *stack_exit_bound =
       nir_imm_int(&b, stack_entry_stride + b.shader->info.shared_size);
-   b.shader->info.shared_size += stack_entry_stride * MAX_STACK_ENTRY_COUNT;
+   const uint32_t lds_stack_size = stack_entry_stride * MAX_STACK_LDS_ENTRY_COUNT;
+   b.shader->info.shared_size += lds_stack_size;
 
    nir_ssa_def *accel_struct = nir_load_var(&b, vars.accel_struct);
 
@@ -1431,6 +1435,7 @@
    nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1);
 
    nir_store_var(&b, trav_vars.stack, stack_base, 1);
+   nir_store_var(&b, trav_vars.lds_stack_base, stack_base, 1);
    nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, 0), 1);
    nir_store_var(&b, trav_vars.current_node, bvh_root, 0x1);
 
@@ -1461,10 +1466,52 @@
 
          nir_store_var(&b, trav_vars.stack,
                        nir_isub(&b, nir_load_var(&b, trav_vars.stack), stack_entry_stride_def), 1);
-         nir_store_var(&b, trav_vars.current_node,
-                       nir_load_shared(&b, 1, 32, nir_load_var(&b, trav_vars.stack), .base = 0,
-                                       .align_mul = stack_entry_size),
-                       0x1);
+         nir_push_if(&b, nir_ilt(&b, nir_load_var(&b, trav_vars.stack),
+                                 nir_load_var(&b, trav_vars.lds_stack_base)));
+         {
+            nir_ssa_def *scratch_addr = nir_imul_imm(
+               &b, nir_udiv_imm(&b, nir_load_var(&b, trav_vars.stack), stack_entry_stride),
+               stack_entry_size);
+            nir_store_var(&b, trav_vars.current_node, nir_load_scratch(&b, 1, 32, scratch_addr),
+                          0x1);
+            nir_store_var(&b, trav_vars.lds_stack_base, nir_load_var(&b, trav_vars.stack), 0x1);
+         }
+         nir_push_else(&b, NULL);
+         {
+            nir_ssa_def *stack_ptr =
+               nir_umod(&b, nir_load_var(&b, trav_vars.stack), nir_imm_int(&b, lds_stack_size));
+            nir_store_var(
+               &b, trav_vars.current_node,
+               nir_load_shared(&b, 1, 32, stack_ptr, .base = 0, .align_mul = stack_entry_size),
+               0x1);
+         }
+         nir_pop_if(&b, NULL);
+      }
+      nir_pop_if(&b, NULL);
+
+      nir_ssa_def *might_overflow =
+         nir_ige(&b,
+                 nir_isub(&b, nir_load_var(&b, trav_vars.stack),
+                          nir_load_var(&b, trav_vars.lds_stack_base)),
+                 nir_imm_int(&b, (MAX_STACK_LDS_ENTRY_COUNT - 2) * stack_entry_stride));
+      nir_push_if(&b, might_overflow);
+      {
+         nir_ssa_def *scratch_addr = nir_imul_imm(
+            &b, nir_udiv_imm(&b, nir_load_var(&b, trav_vars.lds_stack_base), stack_entry_stride),
+            stack_entry_size);
+         for (int i = 0; i < 4; ++i) {
+            nir_ssa_def *lds_stack_ptr = nir_umod(&b, nir_load_var(&b, trav_vars.lds_stack_base),
+                                                  nir_imm_int(&b, lds_stack_size));
+
+            nir_ssa_def *node =
+               nir_load_shared(&b, 1, 32, lds_stack_ptr, .base = 0, .align_mul = stack_entry_size);
+            nir_store_scratch(&b, node, scratch_addr);
+
+            nir_store_var(
+               &b, trav_vars.lds_stack_base,
+               nir_iadd(&b, nir_load_var(&b, trav_vars.lds_stack_base), stack_entry_stride_def), 1);
+            scratch_addr = nir_iadd_imm(&b, scratch_addr, stack_entry_size);
+         }
       }
       nir_pop_if(&b, NULL);
 
@@ -1560,7 +1607,9 @@ build_traversal_shader(struct radv_device *device,
             nir_push_if(&b, nir_ine_imm(&b, new_nodes[i], 0xffffffff));
 
             for (unsigned i = 4; i-- > 1;) {
-               nir_store_shared(&b, new_nodes[i], nir_load_var(&b, trav_vars.stack), .base = 0,
+               nir_ssa_def *stack_ptr =
+                  nir_umod(&b, nir_load_var(&b, trav_vars.stack), nir_imm_int(&b, lds_stack_size));
+               nir_store_shared(&b, new_nodes[i], stack_ptr, .base = 0,
                                 .align_mul = stack_entry_size);
                nir_store_var(
                   &b, trav_vars.stack,
@@ -1772,7 +1821,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo))
       nir_store_var(&b, vars.stack_ptr, nir_load_rt_dynamic_callable_stack_base_amd(&b), 0x1);
    else
-      nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
+      nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, MAX_STACK_SCRATCH_ENTRY_COUNT * 4), 0x1);
 
    nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, true), 1);
 
@@ -1822,11 +1871,9 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    nir_pop_loop(&b, loop);
 
-   if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo)) {
-      /* Put something so scratch gets enabled in the shader. */
-      b.shader->scratch_size = 16;
-   } else
-      b.shader->scratch_size = compute_rt_stack_size(pCreateInfo, stack_sizes);
+   b.shader->scratch_size = MAX2(16, MAX_STACK_SCRATCH_ENTRY_COUNT * 4);
+   if (!radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo))
+      b.shader->scratch_size += compute_rt_stack_size(pCreateInfo, stack_sizes);
 
    /* Deal with all the inline functions. */
    nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
 
diff --git a/src/amd/vulkan/radv_rt_common.h b/src/amd/vulkan/radv_rt_common.h
index 90780459ab7..77188386687 100644
--- a/src/amd/vulkan/radv_rt_common.h
+++ b/src/amd/vulkan/radv_rt_common.h
@@ -68,6 +68,8 @@ nir_ssa_def *create_bvh_descriptor(nir_builder *b);
  * + 1 instance node. Furthermore, when processing a box node, worst case we actually
  * push all 4 children and remove one, so the DFS stack depth is box nodes * 3 + 2.
  */
-#define MAX_STACK_ENTRY_COUNT 76
+#define MAX_STACK_ENTRY_COUNT         76
+#define MAX_STACK_LDS_ENTRY_COUNT     16
+#define MAX_STACK_SCRATCH_ENTRY_COUNT (MAX_STACK_ENTRY_COUNT - MAX_STACK_LDS_ENTRY_COUNT)
 
 #endif
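
Note (reviewer sketch, not part of the patch): the diff splits the traversal DFS stack between a small LDS window of MAX_STACK_LDS_ENTRY_COUNT entries, addressed modulo lds_stack_size, and per-lane scratch. trav_vars.lds_stack_base tracks the lowest stack entry still resident in LDS: a pop that falls below the base is refilled from scratch via nir_load_scratch(), and once the window holds MAX_STACK_LDS_ENTRY_COUNT - 2 or more entries, four of its oldest entries are spilled out with nir_store_scratch(). The C program below is a host-side model of that bookkeeping, written only to make the spill/refill logic easy to follow; the names (hybrid_stack, push, pop, LDS_ENTRIES, ...) are hypothetical, and it checks for overflow after every push rather than once per traversal iteration as the shader does, so treat it as an illustration of the scheme, not the code the NIR builder emits.

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

#define MAX_ENTRIES 76 /* like MAX_STACK_ENTRY_COUNT: worst-case traversal depth */
#define LDS_ENTRIES 16 /* like MAX_STACK_LDS_ENTRY_COUNT: entries kept in LDS    */

struct hybrid_stack {
   uint32_t lds[LDS_ENTRIES];     /* stand-in for the shared-memory window */
   uint32_t scratch[MAX_ENTRIES]; /* stand-in for per-lane scratch, generously sized */
   uint32_t top;                  /* number of live entries == next absolute slot */
   uint32_t lds_base;             /* lowest absolute slot still resident in the window */
};

static void
push(struct hybrid_stack *s, uint32_t node)
{
   assert(s->top < MAX_ENTRIES);
   s->lds[s->top % LDS_ENTRIES] = node; /* the shader does umod(stack, lds_stack_size) */
   s->top++;

   /* Window close to overflowing? Move a few of the oldest entries to the
    * backing store, like the might_overflow block above (which moves 4). */
   if (s->top - s->lds_base >= LDS_ENTRIES - 2) {
      for (int i = 0; i < 4; ++i) {
         s->scratch[s->lds_base] = s->lds[s->lds_base % LDS_ENTRIES];
         s->lds_base++;
      }
   }
}

static uint32_t
pop(struct hybrid_stack *s)
{
   assert(s->top > 0);
   s->top--;
   if (s->top < s->lds_base) {
      /* Window ran dry: read the entry back from the backing store and pull
       * the base down, like the nir_ilt() branch that uses nir_load_scratch(). */
      s->lds_base = s->top;
      return s->scratch[s->top];
   }
   return s->lds[s->top % LDS_ENTRIES];
}

int
main(void)
{
   struct hybrid_stack s = {.top = 0, .lds_base = 0};
   unsigned depth = 70; /* deeper than the 16-entry window, so spills do happen */

   for (unsigned i = 0; i < depth; ++i)
      push(&s, i);
   while (depth--) {
      if (pop(&s) != depth) {
         printf("mismatch at entry %u\n", depth);
         return 1;
      }
   }
   printf("hybrid LDS/scratch stack model OK\n");
   return 0;
}

The payoff the diff itself shows: shared memory only has to hold MAX_STACK_LDS_ENTRY_COUNT stack entries instead of MAX_STACK_ENTRY_COUNT (the shared_size change), while the first MAX_STACK_SCRATCH_ENTRY_COUNT * 4 bytes of scratch are reserved for spilled entries, which is why create_rt_shader now starts the callable stack pointer at that offset and always enables at least that much scratch.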