radv/rt: Lower ray payloads to registers

This should allow for cross stage optimizations and it reduces latency
caused by scratch access.

Totals from 44 (9.69% of 454) affected shaders:
MaxWaves: 432 -> 436 (+0.93%)
Instrs: 2740662 -> 1610327 (-41.24%); split: -41.24%, +0.00%
CodeSize: 14616932 -> 8573620 (-41.34%)
VGPRs: 4880 -> 4816 (-1.31%)
SpillSGPRs: 464 -> 294 (-36.64%)
Latency: 18548886 -> 11465281 (-38.19%); split: -38.19%, +0.00%
InvThroughput: 5195964 -> 3066729 (-40.98%); split: -40.98%, +0.00%
VClause: 99672 -> 55611 (-44.21%)
SClause: 65827 -> 38697 (-41.21%)
Copies: 231231 -> 137676 (-40.46%); split: -40.47%, +0.01%
Branches: 111379 -> 65865 (-40.86%); split: -40.87%, +0.00%
PreSGPRs: 3854 -> 3812 (-1.09%); split: -1.19%, +0.10%
PreVGPRs: 4518 -> 4439 (-1.75%); split: -1.84%, +0.09%

Reviewed-by: Friedrich Vock <friedrich.vock@gmx.de>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26431>
This commit is contained in:
Konstantin Seurer
2023-11-15 15:03:09 +01:00
committed by Marge Bot
parent 8e6d28f473
commit 658ce711d5
3 changed files with 72 additions and 26 deletions

View File

@@ -187,6 +187,7 @@ lower_rt_derefs(nir_shader *shader)
struct rt_variables {
struct radv_device *device;
const VkPipelineCreateFlags2KHR flags;
bool monolithic;
/* idx of the next shader to run in the next iteration of the main loop.
* During traversal, idx is used to store the SBT index and will contain
@@ -230,11 +231,13 @@ struct rt_variables {
};
static struct rt_variables
create_rt_variables(nir_shader *shader, struct radv_device *device, const VkPipelineCreateFlags2KHR flags)
create_rt_variables(nir_shader *shader, struct radv_device *device, const VkPipelineCreateFlags2KHR flags,
bool monolithic)
{
struct rt_variables vars = {
.device = device,
.flags = flags,
.monolithic = monolithic,
};
vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
vars.shader_addr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_addr");
@@ -790,7 +793,7 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni
nir_opt_dead_cf(shader);
struct rt_variables src_vars = create_rt_variables(shader, vars->device, vars->flags);
struct rt_variables src_vars = create_rt_variables(shader, vars->device, vars->flags, vars->monolithic);
map_rt_variables(var_remap, &src_vars, vars);
NIR_PASS_V(shader, lower_rt_instructions, &src_vars, false);
@@ -800,6 +803,23 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni
inline_constants(b->shader, shader);
if (vars->monolithic) {
/* Work around vulkancts declaring an incompatible dummy payload. */
nir_remove_dead_variables(shader, nir_var_shader_call_data, NULL);
nir_foreach_variable_in_shader (var, shader) {
if (var->data.mode != nir_var_shader_call_data)
continue;
/* There can only be one shader_call_data variable which has to match the caller payload. */
if (var->type != vars->arg->type)
return;
_mesa_hash_table_insert(var_remap, var, vars->arg);
break;
}
}
nir_push_if(b, nir_ieq_imm(b, idx, call_idx));
nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
nir_pop_if(b, NULL);
@@ -807,25 +827,17 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni
ralloc_free(var_remap);
}
nir_shader *
radv_parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo,
const struct radv_pipeline_key *key, const struct radv_pipeline_layout *pipeline_layout)
void
radv_nir_lower_rt_io(nir_shader *nir, bool monolithic)
{
struct radv_shader_stage rt_stage;
if (!monolithic) {
NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data,
glsl_get_natural_size_align_bytes);
radv_pipeline_stage_init(sinfo, pipeline_layout, &rt_stage);
NIR_PASS(_, nir, lower_rt_derefs);
nir_shader *shader = radv_shader_spirv_to_nir(device, &rt_stage, key, false);
NIR_PASS(_, shader, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data,
glsl_get_natural_size_align_bytes);
NIR_PASS(_, shader, lower_rt_derefs);
NIR_PASS(_, shader, radv_nir_lower_hit_attrib_derefs);
NIR_PASS(_, shader, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset);
return shader;
NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset);
}
}
static nir_function_impl *
@@ -1123,6 +1135,8 @@ radv_build_ahit_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_g
radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir);
assert(nir_stage);
radv_nir_lower_rt_io(nir_stage, data->vars->monolithic);
insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.any_hit_index);
ralloc_free(nir_stage);
}
@@ -1145,12 +1159,16 @@ radv_build_isec_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_g
radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->intersection_shader].nir);
assert(nir_stage);
radv_nir_lower_rt_io(nir_stage, data->vars->monolithic);
nir_shader *any_hit_stage = NULL;
if (group->any_hit_shader != VK_SHADER_UNUSED_KHR) {
any_hit_stage =
radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir);
assert(any_hit_stage);
radv_nir_lower_rt_io(any_hit_stage, data->vars->monolithic);
/* reserve stack size for any_hit before it is inlined */
data->pipeline->stages[group->any_hit_shader].stack_size = any_hit_stage->scratch_size;
@@ -1193,6 +1211,8 @@ radv_build_recursive_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_trac
radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->recursive_shader].nir);
assert(nir_stage);
radv_nir_lower_rt_io(nir_stage, data->vars->monolithic);
insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.general_index);
ralloc_free(nir_stage);
}
@@ -1514,7 +1534,7 @@ radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_
b.shader->info.workgroup_size[0] = 8;
b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
b.shader->info.shared_size = device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
struct rt_variables vars = create_rt_variables(b.shader, device, create_flags);
struct rt_variables vars = create_rt_variables(b.shader, device, create_flags, false);
/* initialize trace_ray arguments */
nir_store_var(&b, vars.accel_struct, nir_load_accel_struct_amd(&b), 1);
@@ -1567,7 +1587,9 @@ lower_rt_instruction_monolithic(nir_builder *b, nir_instr *instr, void *data)
case nir_intrinsic_execute_callable:
unreachable("nir_intrinsic_execute_callable");
case nir_intrinsic_trace_ray: {
nir_store_var(b, vars->arg, nir_iadd_imm(b, intr->src[10].ssa, -b->shader->scratch_size), 1);
nir_deref_instr *payload_deref = nir_src_as_deref(intr->src[10]);
assert(payload_deref->deref_type == nir_deref_type_var);
vars->arg = payload_deref->var;
nir_src cull_mask = intr->src[2];
bool ignore_cull_mask = nir_src_is_const(cull_mask) && (nir_src_as_uint(cull_mask) & 0xFF) == 0xFF;
@@ -1609,6 +1631,20 @@ lower_rt_instruction_monolithic(nir_builder *b, nir_instr *instr, void *data)
}
}
static bool
radv_rewrite_call_data_deref_modes(nir_builder *b, nir_instr *instr, void *data)
{
if (instr->type != nir_instr_type_deref)
return false;
nir_deref_instr *deref = nir_instr_as_deref(instr);
if (!nir_deref_mode_is(deref, nir_var_shader_call_data))
return false;
deref->modes = nir_var_shader_temp;
return true;
}
static void
lower_rt_instructions_monolithic(nir_shader *shader, struct radv_device *device,
struct radv_ray_tracing_pipeline *pipeline,
@@ -1626,6 +1662,8 @@ lower_rt_instructions_monolithic(nir_shader *shader, struct radv_device *device,
nir_shader_instructions_pass(shader, lower_rt_instruction_monolithic, nir_metadata_none, &state);
nir_index_ssa_defs(impl);
nir_shader_instructions_pass(shader, radv_rewrite_call_data_deref_modes, nir_metadata_all, &state);
/* Register storage for hit attributes */
nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_SIZE / sizeof(uint32_t)];
@@ -1682,7 +1720,7 @@ radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKH
const VkPipelineCreateFlagBits2KHR create_flags = vk_rt_pipeline_create_flags(pCreateInfo);
struct rt_variables vars = create_rt_variables(shader, device, create_flags);
struct rt_variables vars = create_rt_variables(shader, device, create_flags, monolithic);
if (monolithic)
lower_rt_instructions_monolithic(shader, device, pipeline, pCreateInfo, &vars);
@@ -1713,7 +1751,9 @@ radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKH
nir_store_var(&b, vars.stack_ptr, ac_nir_load_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base), 1);
nir_def *record_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_record);
nir_store_var(&b, vars.shader_record_ptr, nir_pack_64_2x32(&b, record_ptr), 1);
nir_store_var(&b, vars.arg, ac_nir_load_arg(&b, &args->ac, args->ac.rt.payload_offset), 1);
if (!monolithic)
nir_store_var(&b, vars.arg, ac_nir_load_arg(&b, &args->ac, args->ac.rt.payload_offset), 1);
nir_def *accel_struct = ac_nir_load_arg(&b, &args->ac, args->ac.rt.accel_struct);
nir_store_var(&b, vars.accel_struct, nir_pack_64_2x32(&b, accel_struct), 1);
@@ -1786,4 +1826,7 @@ radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKH
NIR_PASS_V(shader, nir_lower_vars_to_ssa);
if (shader->info.stage == MESA_SHADER_CLOSEST_HIT || shader->info.stage == MESA_SHADER_INTERSECTION)
NIR_PASS_V(shader, lower_hit_attribs, NULL, info->wave_size);
if (monolithic)
ac_nir_lower_indirect_derefs(shader, device->physical_device->rad_info.gfx_level);
}

View File

@@ -24,6 +24,7 @@
#include "nir/nir.h"
#include "nir/nir_builder.h"
#include "nir/radv_nir.h"
#include "radv_debug.h"
#include "radv_private.h"
#include "radv_shader.h"
@@ -365,6 +366,8 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
bool keep_executable_info = radv_pipeline_capture_shaders(device, pipeline->base.base.create_flags);
bool keep_statistic_info = radv_pipeline_capture_shader_stats(device, pipeline->base.base.create_flags);
radv_nir_lower_rt_io(stage->nir, monolithic);
/* Gather shader info. */
nir_shader_gather_info(stage->nir, nir_shader_get_entrypoint(stage->nir));
radv_nir_shader_info_init(stage->stage, MESA_SHADER_NONE, &stage->info);
@@ -517,7 +520,9 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca
radv_pipeline_stage_init(&pCreateInfo->pStages[i], pipeline_layout, stage);
/* precompile the shader */
stage->nir = radv_parse_rt_stage(device, &pCreateInfo->pStages[i], key, pipeline_layout);
stage->nir = radv_shader_spirv_to_nir(device, stage, key, false);
NIR_PASS(_, stage->nir, radv_nir_lower_hit_attrib_derefs);
rt_stages[i].can_inline = radv_rt_can_inline_shader(stage->nir);

View File

@@ -725,9 +725,7 @@ void radv_postprocess_nir(struct radv_device *device, const struct radv_pipeline
bool radv_shader_should_clear_lds(const struct radv_device *device, const nir_shader *shader);
nir_shader *radv_parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo,
const struct radv_pipeline_key *key,
const struct radv_pipeline_layout *pipeline_layout);
void radv_nir_lower_rt_io(nir_shader *shader, bool monolithic);
void radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
const struct radv_shader_args *args, const struct radv_shader_info *info,