diff --git a/src/amd/vulkan/radv_nir_lower_ray_queries.c b/src/amd/vulkan/radv_nir_lower_ray_queries.c
index e6786b09bcd..579037984cc 100644
--- a/src/amd/vulkan/radv_nir_lower_ray_queries.c
+++ b/src/amd/vulkan/radv_nir_lower_ray_queries.c
@@ -34,7 +34,8 @@
 /* Traversal stack size. Traversal supports backtracking so we can go deeper than this size if
  * needed. However, we keep a large stack size to avoid it being put into registers, which hurts
  * occupancy. */
-#define MAX_STACK_ENTRY_COUNT 76
+#define MAX_SCRATCH_STACK_ENTRY_COUNT 76
+#define MAX_SHARED_STACK_ENTRY_COUNT 8
 
 typedef struct {
    nir_variable *variable;
@@ -176,6 +177,7 @@
    struct ray_query_traversal_vars trav;
 
    rq_variable *stack;
+   uint32_t shared_base;
 };
 
 #define VAR_NAME(name) \
@@ -244,7 +246,7 @@ init_ray_query_intersection_vars(void *ctx, nir_shader *shader, unsigned array_l
 
 static void
 init_ray_query_vars(nir_shader *shader, unsigned array_length, struct ray_query_vars *dst,
-                    const char *base_name)
+                    const char *base_name, uint32_t max_shared_size)
 {
    void *ctx = dst;
    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
@@ -268,16 +270,27 @@ init_ray_query_vars(nir_shader *shader, unsigned array_length, struct ray_query_
 
    dst->trav = init_ray_query_traversal_vars(dst, shader, array_length, VAR_NAME("_top"));
 
-   dst->stack = rq_variable_create(dst, shader, array_length,
-                                   glsl_array_type(glsl_uint_type(), MAX_STACK_ENTRY_COUNT,
-                                                   glsl_get_explicit_stride(glsl_uint_type())),
-                                   VAR_NAME("_stack"));
+   uint32_t workgroup_size = shader->info.workgroup_size[0] * shader->info.workgroup_size[1] *
+                             shader->info.workgroup_size[2];
+   uint32_t shared_stack_size = workgroup_size * MAX_SHARED_STACK_ENTRY_COUNT * 4;
+   uint32_t shared_offset = align(shader->info.shared_size, 4);
+   if (shader->info.stage != MESA_SHADER_COMPUTE || array_length > 1 ||
+       shared_offset + shared_stack_size > max_shared_size) {
+      dst->stack = rq_variable_create(
+         dst, shader, array_length,
+         glsl_array_type(glsl_uint_type(), MAX_SCRATCH_STACK_ENTRY_COUNT, 0), VAR_NAME("_stack"));
+   } else {
+      dst->stack = NULL;
+      dst->shared_base = shared_offset;
+      shader->info.shared_size = shared_offset + shared_stack_size;
+   }
 }
 #undef VAR_NAME
 
 static void
-lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *ht)
+lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *ht,
+                uint32_t max_shared_size)
 {
    struct ray_query_vars *vars = ralloc(ht, struct ray_query_vars);
 
@@ -285,7 +298,8 @@ lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *
    if (glsl_type_is_array(ray_query->type))
       array_length = glsl_get_length(ray_query->type);
 
-   init_ray_query_vars(shader, array_length, vars, ray_query->name == NULL ? "" : ray_query->name);
+   init_ray_query_vars(shader, array_length, vars, ray_query->name == NULL ? "" : ray_query->name,
+                       max_shared_size);
 
    _mesa_hash_table_insert(ht, ray_query, vars);
 }
@@ -385,7 +399,17 @@ lower_rq_initialize(nir_builder *b, nir_ssa_def *index, nir_intrinsic_instr *ins
       rq_store_var(b, index, vars->root_bvh_base, bvh_base, 0x1);
       rq_store_var(b, index, vars->trav.bvh_base, bvh_base, 1);
 
-      rq_store_var(b, index, vars->trav.stack, nir_imm_int(b, 0), 0x1);
+      if (vars->stack) {
+         rq_store_var(b, index, vars->trav.stack, nir_imm_int(b, 0), 0x1);
+         rq_store_var(b, index, vars->trav.stack_base, nir_imm_int(b, 0), 0x1);
+      } else {
+         nir_ssa_def *base_offset =
+            nir_imul_imm(b, nir_load_local_invocation_index(b), sizeof(uint32_t));
+         base_offset = nir_iadd_imm(b, base_offset, vars->shared_base);
+         rq_store_var(b, index, vars->trav.stack, base_offset, 0x1);
+         rq_store_var(b, index, vars->trav.stack_base, base_offset, 0x1);
+      }
+
       rq_store_var(b, index, vars->trav.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
       rq_store_var(b, index, vars->trav.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
       rq_store_var(b, index, vars->trav.instance_top_node, nir_imm_int(b, RADV_BVH_INVALID_NODE),
@@ -393,7 +417,6 @@ lower_rq_initialize(nir_builder *b, nir_ssa_def *index, nir_intrinsic_instr *ins
       rq_store_var(b, index, vars->trav.instance_bottom_node,
                    nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 0x1);
       rq_store_var(b, index, vars->trav.top_stack, nir_imm_int(b, -1), 1);
-      rq_store_var(b, index, vars->trav.stack_base, nir_imm_int(b, 0), 1);
    }
    nir_push_else(b, NULL);
    {
@@ -614,14 +637,20 @@ store_stack_entry(nir_builder *b, nir_ssa_def *index, nir_ssa_def *value,
                   const struct radv_ray_traversal_args *args)
 {
    struct traversal_data *data = args->data;
-   rq_store_array(b, data->index, data->vars->stack, index, value, 1);
+   if (data->vars->stack)
+      rq_store_array(b, data->index, data->vars->stack, index, value, 1);
+   else
+      nir_store_shared(b, value, index, .base = 0, .align_mul = 4);
 }
 
 static nir_ssa_def *
 load_stack_entry(nir_builder *b, nir_ssa_def *index, const struct radv_ray_traversal_args *args)
 {
    struct traversal_data *data = args->data;
-   return rq_load_array(b, data->index, data->vars->stack, index);
+   if (data->vars->stack)
+      return rq_load_array(b, data->index, data->vars->stack, index);
+   else
+      return nir_load_shared(b, 1, 32, index, .base = 0, .align_mul = 4);
 }
 
 static nir_ssa_def *
@@ -658,8 +687,6 @@ lower_rq_proceed(nir_builder *b, nir_ssa_def *index, struct ray_query_vars *vars
       .tmin = rq_load_var(b, index, vars->tmin),
       .dir = rq_load_var(b, index, vars->direction),
       .vars = trav_vars,
-      .stack_stride = 1,
-      .stack_entries = MAX_STACK_ENTRY_COUNT,
       .stack_store_cb = store_stack_entry,
       .stack_load_cb = load_stack_entry,
       .aabb_cb = handle_candidate_aabb,
@@ -667,6 +694,17 @@ lower_rq_proceed(nir_builder *b, nir_ssa_def *index, struct ray_query_vars *vars
       .data = &data,
    };
 
+   if (vars->stack) {
+      args.stack_stride = 1;
+      args.stack_entries = MAX_SCRATCH_STACK_ENTRY_COUNT;
+   } else {
+      uint32_t workgroup_size = b->shader->info.workgroup_size[0] *
+                                b->shader->info.workgroup_size[1] *
+                                b->shader->info.workgroup_size[2];
+      args.stack_stride = workgroup_size * 4;
+      args.stack_entries = MAX_SHARED_STACK_ENTRY_COUNT;
+   }
+
    nir_push_if(b, rq_load_var(b, index, vars->incomplete));
    {
       nir_ssa_def *incomplete = radv_build_ray_traversal(device, b, &args);
@@ -695,7 +733,7 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device
       if (!var->data.ray_query)
          continue;
 
-      lower_ray_query(shader, var, query_ht);
+      lower_ray_query(shader, var, query_ht, device->physical_device->max_shared_size);
 
       contains_ray_query = true;
    }
@@ -710,7 +748,7 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device
          if (!var->data.ray_query)
            continue;
 
-         lower_ray_query(shader, var, query_ht);
+         lower_ray_query(shader, var, query_ht, device->physical_device->max_shared_size);
 
         contains_ray_query = true;
      }
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index 8a6e7c0419f..6f7b6284948 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -977,6 +977,16 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_pipeline_
    nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
 
    if (nir->info.ray_queries > 0) {
+      /* Lower shared variables early to prevent the over allocation of shared memory in
+       * radv_nir_lower_ray_queries. */
+      if (nir->info.stage == MESA_SHADER_COMPUTE) {
+         if (!nir->info.shared_memory_explicit_layout)
+            NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_mem_shared, shared_var_info);
+
+         NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_mem_shared,
+                  nir_address_format_32bit_offset);
+      }
+
       NIR_PASS(_, nir, nir_opt_ray_queries);
       NIR_PASS(_, nir, nir_opt_ray_query_ranges);
       NIR_PASS(_, nir, radv_nir_lower_ray_queries, device);
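Note on the layout introduced above (a minimal standalone sketch, not part of the patch): the shared-memory stack is interleaved across the workgroup, so invocation N's entry i lives at shared_base + N * 4 + i * workgroup_size * 4 bytes. That is why lower_rq_initialize() seeds the stack pointer with local_invocation_index * 4 + shared_base and lower_rq_proceed() sets stack_stride to workgroup_size * 4; with MAX_SHARED_STACK_ENTRY_COUNT = 8 the stack consumes workgroup_size * 32 bytes of LDS (2 KiB for a 64-invocation workgroup). The helper and driver below are hypothetical and only restate that addressing:

/* Illustrative sketch only (not part of the patch): byte offset of one
 * invocation's i-th shared-memory stack entry under the interleaved layout
 * described above. All names here are hypothetical. */
#include <stdint.h>
#include <stdio.h>

#define MAX_SHARED_STACK_ENTRY_COUNT 8

static uint32_t
shared_stack_entry_offset(uint32_t shared_base, uint32_t workgroup_size,
                          uint32_t local_invocation_index, uint32_t entry)
{
   /* Matches the base stored in lower_rq_initialize(): one dword per lane,
    * packed right after the shader's existing shared memory. */
   uint32_t base = shared_base + local_invocation_index * 4;
   /* Matches args.stack_stride in lower_rq_proceed(): consecutive entries of
    * one lane are workgroup_size dwords apart. */
   return base + entry * workgroup_size * 4;
}

int
main(void)
{
   const uint32_t workgroup_size = 64; /* hypothetical 64-invocation workgroup */
   for (uint32_t entry = 0; entry < MAX_SHARED_STACK_ENTRY_COUNT; entry++)
      printf("lane 3, entry %u -> byte offset %u\n", entry,
             shared_stack_entry_offset(/*shared_base=*/0, workgroup_size, 3, entry));
   return 0;
}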