diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c index 306ad5ca873..b4bea1220e6 100644 --- a/src/amd/vulkan/radv_pipeline_rt.c +++ b/src/amd/vulkan/radv_pipeline_rt.c @@ -643,6 +643,13 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache, if (result != VK_SUCCESS) goto pipeline_fail; + struct radv_ray_tracing_stage *stages = calloc(local_create_info.stageCount, sizeof(*stages)); + if (!stages) { + result = VK_ERROR_OUT_OF_HOST_MEMORY; + goto pipeline_fail; + } + radv_rt_fill_stage_info(pCreateInfo, stages); + const VkPipelineCreationFeedbackCreateInfo *creation_feedback = vk_find_struct_const(pCreateInfo->pNext, PIPELINE_CREATION_FEEDBACK_CREATE_INFO); @@ -664,7 +671,11 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache, if (pCreateInfo->flags & VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT) goto pipeline_fail; - shader = create_rt_shader(device, &local_create_info, rt_pipeline->groups, &key); + result = radv_rt_precompile_shaders(device, cache, pCreateInfo, &key, stages); + if (result != VK_SUCCESS) + goto shader_fail; + + shader = create_rt_shader(device, &local_create_info, stages, rt_pipeline->groups, &key); module.nir = shader; result = radv_rt_pipeline_compile(rt_pipeline, pipeline_layout, device, cache, &key, &stage, pCreateInfo->flags, hash, creation_feedback, @@ -688,8 +699,15 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache, radv_rmv_log_compute_pipeline_create(device, pCreateInfo->flags, &rt_pipeline->base.base, false); *pPipeline = radv_pipeline_to_handle(&rt_pipeline->base.base); + shader_fail: + for (unsigned i = 0; stages && i < local_create_info.stageCount; i++) { + if (stages[i].shader) + vk_pipeline_cache_object_unref(&device->vk, stages[i].shader); + } ralloc_free(shader); + free(stages); + pipeline_fail: if (result != VK_SUCCESS) radv_pipeline_destroy(device, &rt_pipeline->base.base, pAllocator); diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c index 52dfd24c408..19a5997927b 100644 --- a/src/amd/vulkan/radv_rt_shader.c +++ b/src/amd/vulkan/radv_rt_shader.c @@ -1190,6 +1190,7 @@ struct traversal_data { nir_variable *barycentrics; struct radv_ray_tracing_group *groups; + struct radv_ray_tracing_stage *stages; const struct radv_pipeline_key *key; }; @@ -1226,8 +1227,9 @@ visit_any_hit_shaders(struct radv_device *device, if (is_dup) continue; - const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id]; - nir_shader *nir_stage = radv_parse_rt_stage(device, stage, data->key); + nir_shader *nir_stage = + radv_pipeline_cache_handle_to_nir(device, data->stages[shader_id].shader); + assert(nir_stage); insert_rt_case(b, nir_stage, vars, sbt_idx, 0, data->groups[i].handle.any_hit_index, shader_id, data->groups); @@ -1363,13 +1365,15 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio if (is_dup) continue; - const VkPipelineShaderStageCreateInfo *stage = &data->createInfo->pStages[shader_id]; - nir_shader *nir_stage = radv_parse_rt_stage(data->device, stage, data->key); + nir_shader *nir_stage = + radv_pipeline_cache_handle_to_nir(data->device, data->stages[shader_id].shader); + assert(nir_stage); nir_shader *any_hit_stage = NULL; if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) { - stage = &data->createInfo->pStages[any_hit_shader_id]; - any_hit_stage = radv_parse_rt_stage(data->device, stage, data->key); + any_hit_stage = + radv_pipeline_cache_handle_to_nir(data->device, data->stages[any_hit_shader_id].shader); + assert(any_hit_stage); nir_lower_intersection_shader(nir_stage, any_hit_stage); ralloc_free(any_hit_stage); @@ -1421,7 +1425,7 @@ load_stack_entry(nir_builder *b, nir_ssa_def *index, const struct radv_ray_trave } static nir_shader * -build_traversal_shader(struct radv_device *device, +build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_stage *stages, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct radv_ray_tracing_group *groups, const struct radv_pipeline_key *key) { @@ -1515,6 +1519,7 @@ build_traversal_shader(struct radv_device *device, .trav_vars = &trav_vars, .barycentrics = barycentrics, .groups = groups, + .stages = stages, .key = key, }; @@ -1619,7 +1624,8 @@ move_rt_instructions(nir_shader *shader) nir_shader * create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, - struct radv_ray_tracing_group *groups, const struct radv_pipeline_key *key) + struct radv_ray_tracing_stage *stages, struct radv_ray_tracing_group *groups, + const struct radv_pipeline_key *key) { nir_builder b = radv_meta_init_shader(device, MESA_SHADER_RAYGEN, "rt_combined"); b.shader->info.internal = false; @@ -1635,7 +1641,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf nir_ssa_def *idx = nir_load_var(&b, vars.idx); /* Insert traversal shader */ - nir_shader *traversal = build_traversal_shader(device, pCreateInfo, groups, key); + nir_shader *traversal = build_traversal_shader(device, stages, pCreateInfo, groups, key); b.shader->info.shared_size = MAX2(b.shader->info.shared_size, traversal->info.shared_size); assert(b.shader->info.shared_size <= 32768); insert_rt_case(&b, traversal, &vars, idx, 0, 1, -1u, groups); @@ -1657,13 +1663,12 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf if (is_dup) continue; - const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[stage_idx]; - ASSERTED gl_shader_stage type = vk_to_mesa_shader_stage(stage->stage); + nir_shader *nir_stage = radv_pipeline_cache_handle_to_nir(device, stages[stage_idx].shader); + assert(nir_stage); + ASSERTED gl_shader_stage type = nir_stage->info.stage; assert(type == MESA_SHADER_RAYGEN || type == MESA_SHADER_CALLABLE || type == MESA_SHADER_CLOSEST_HIT || type == MESA_SHADER_MISS); - nir_shader *nir_stage = radv_parse_rt_stage(device, stage, key); - /* Move ray tracing system values to the top that are set by rt_trace_ray * to prevent them from being overwritten by other rt_trace_ray calls. */ diff --git a/src/amd/vulkan/radv_shader.h b/src/amd/vulkan/radv_shader.h index 4559c74b068..96fc1422819 100644 --- a/src/amd/vulkan/radv_shader.h +++ b/src/amd/vulkan/radv_shader.h @@ -43,6 +43,7 @@ struct radv_physical_device; struct radv_device; struct radv_pipeline; +struct radv_ray_tracing_stage; struct radv_ray_tracing_group; struct radv_pipeline_key; struct radv_shader_args; @@ -765,6 +766,7 @@ void radv_get_nir_options(struct radv_physical_device *device); nir_shader *create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, + struct radv_ray_tracing_stage *stages, struct radv_ray_tracing_group *groups, const struct radv_pipeline_key *key);