radv/rt: use precompiled stages to create RT shader

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22100>
This commit is contained in:
Daniel Schürmann
2023-03-23 15:18:29 +01:00
committed by Marge Bot
parent 7836e32778
commit 8ec81a43cb
3 changed files with 39 additions and 14 deletions

View File

@@ -643,6 +643,13 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
if (result != VK_SUCCESS)
goto pipeline_fail;
struct radv_ray_tracing_stage *stages = calloc(local_create_info.stageCount, sizeof(*stages));
if (!stages) {
result = VK_ERROR_OUT_OF_HOST_MEMORY;
goto pipeline_fail;
}
radv_rt_fill_stage_info(pCreateInfo, stages);
const VkPipelineCreationFeedbackCreateInfo *creation_feedback =
vk_find_struct_const(pCreateInfo->pNext, PIPELINE_CREATION_FEEDBACK_CREATE_INFO);
@@ -664,7 +671,11 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
if (pCreateInfo->flags & VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT)
goto pipeline_fail;
shader = create_rt_shader(device, &local_create_info, rt_pipeline->groups, &key);
result = radv_rt_precompile_shaders(device, cache, pCreateInfo, &key, stages);
if (result != VK_SUCCESS)
goto shader_fail;
shader = create_rt_shader(device, &local_create_info, stages, rt_pipeline->groups, &key);
module.nir = shader;
result = radv_rt_pipeline_compile(rt_pipeline, pipeline_layout, device, cache, &key, &stage,
pCreateInfo->flags, hash, creation_feedback,
@@ -688,8 +699,15 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
radv_rmv_log_compute_pipeline_create(device, pCreateInfo->flags, &rt_pipeline->base.base, false);
*pPipeline = radv_pipeline_to_handle(&rt_pipeline->base.base);
shader_fail:
for (unsigned i = 0; stages && i < local_create_info.stageCount; i++) {
if (stages[i].shader)
vk_pipeline_cache_object_unref(&device->vk, stages[i].shader);
}
ralloc_free(shader);
free(stages);
pipeline_fail:
if (result != VK_SUCCESS)
radv_pipeline_destroy(device, &rt_pipeline->base.base, pAllocator);

View File

@@ -1190,6 +1190,7 @@ struct traversal_data {
nir_variable *barycentrics;
struct radv_ray_tracing_group *groups;
struct radv_ray_tracing_stage *stages;
const struct radv_pipeline_key *key;
};
@@ -1226,8 +1227,9 @@ visit_any_hit_shaders(struct radv_device *device,
if (is_dup)
continue;
const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
nir_shader *nir_stage = radv_parse_rt_stage(device, stage, data->key);
nir_shader *nir_stage =
radv_pipeline_cache_handle_to_nir(device, data->stages[shader_id].shader);
assert(nir_stage);
insert_rt_case(b, nir_stage, vars, sbt_idx, 0, data->groups[i].handle.any_hit_index,
shader_id, data->groups);
@@ -1363,13 +1365,15 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
if (is_dup)
continue;
const VkPipelineShaderStageCreateInfo *stage = &data->createInfo->pStages[shader_id];
nir_shader *nir_stage = radv_parse_rt_stage(data->device, stage, data->key);
nir_shader *nir_stage =
radv_pipeline_cache_handle_to_nir(data->device, data->stages[shader_id].shader);
assert(nir_stage);
nir_shader *any_hit_stage = NULL;
if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
stage = &data->createInfo->pStages[any_hit_shader_id];
any_hit_stage = radv_parse_rt_stage(data->device, stage, data->key);
any_hit_stage =
radv_pipeline_cache_handle_to_nir(data->device, data->stages[any_hit_shader_id].shader);
assert(any_hit_stage);
nir_lower_intersection_shader(nir_stage, any_hit_stage);
ralloc_free(any_hit_stage);
@@ -1421,7 +1425,7 @@ load_stack_entry(nir_builder *b, nir_ssa_def *index, const struct radv_ray_trave
}
static nir_shader *
build_traversal_shader(struct radv_device *device,
build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_stage *stages,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
struct radv_ray_tracing_group *groups, const struct radv_pipeline_key *key)
{
@@ -1515,6 +1519,7 @@ build_traversal_shader(struct radv_device *device,
.trav_vars = &trav_vars,
.barycentrics = barycentrics,
.groups = groups,
.stages = stages,
.key = key,
};
@@ -1619,7 +1624,8 @@ move_rt_instructions(nir_shader *shader)
nir_shader *
create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
struct radv_ray_tracing_group *groups, const struct radv_pipeline_key *key)
struct radv_ray_tracing_stage *stages, struct radv_ray_tracing_group *groups,
const struct radv_pipeline_key *key)
{
nir_builder b = radv_meta_init_shader(device, MESA_SHADER_RAYGEN, "rt_combined");
b.shader->info.internal = false;
@@ -1635,7 +1641,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
nir_ssa_def *idx = nir_load_var(&b, vars.idx);
/* Insert traversal shader */
nir_shader *traversal = build_traversal_shader(device, pCreateInfo, groups, key);
nir_shader *traversal = build_traversal_shader(device, stages, pCreateInfo, groups, key);
b.shader->info.shared_size = MAX2(b.shader->info.shared_size, traversal->info.shared_size);
assert(b.shader->info.shared_size <= 32768);
insert_rt_case(&b, traversal, &vars, idx, 0, 1, -1u, groups);
@@ -1657,13 +1663,12 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
if (is_dup)
continue;
const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[stage_idx];
ASSERTED gl_shader_stage type = vk_to_mesa_shader_stage(stage->stage);
nir_shader *nir_stage = radv_pipeline_cache_handle_to_nir(device, stages[stage_idx].shader);
assert(nir_stage);
ASSERTED gl_shader_stage type = nir_stage->info.stage;
assert(type == MESA_SHADER_RAYGEN || type == MESA_SHADER_CALLABLE ||
type == MESA_SHADER_CLOSEST_HIT || type == MESA_SHADER_MISS);
nir_shader *nir_stage = radv_parse_rt_stage(device, stage, key);
/* Move ray tracing system values to the top that are set by rt_trace_ray
* to prevent them from being overwritten by other rt_trace_ray calls.
*/

View File

@@ -43,6 +43,7 @@
struct radv_physical_device;
struct radv_device;
struct radv_pipeline;
struct radv_ray_tracing_stage;
struct radv_ray_tracing_group;
struct radv_pipeline_key;
struct radv_shader_args;
@@ -765,6 +766,7 @@ void radv_get_nir_options(struct radv_physical_device *device);
nir_shader *create_rt_shader(struct radv_device *device,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
struct radv_ray_tracing_stage *stages,
struct radv_ray_tracing_group *groups,
const struct radv_pipeline_key *key);