diff --git a/src/amd/vulkan/nir/radv_nir_rt_shader.c b/src/amd/vulkan/nir/radv_nir_rt_shader.c index 90b39fb11d9..d78dfecbae3 100644 --- a/src/amd/vulkan/nir/radv_nir_rt_shader.c +++ b/src/amd/vulkan/nir/radv_nir_rt_shader.c @@ -187,7 +187,6 @@ lower_rt_derefs(nir_shader *shader) struct rt_variables { struct radv_device *device; const VkPipelineCreateFlags2KHR flags; - bool monolithic; /* idx of the next shader to run in the next iteration of the main loop. * During traversal, idx is used to store the SBT index and will contain @@ -231,13 +230,11 @@ struct rt_variables { }; static struct rt_variables -create_rt_variables(nir_shader *shader, struct radv_device *device, const VkPipelineCreateFlags2KHR flags, - bool monolithic) +create_rt_variables(nir_shader *shader, struct radv_device *device, const VkPipelineCreateFlags2KHR flags) { struct rt_variables vars = { .device = device, .flags = flags, - .monolithic = monolithic, }; vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx"); vars.shader_addr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_addr"); @@ -793,7 +790,7 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni nir_opt_dead_cf(shader); - struct rt_variables src_vars = create_rt_variables(shader, vars->device, vars->flags, vars->monolithic); + struct rt_variables src_vars = create_rt_variables(shader, vars->device, vars->flags); map_rt_variables(var_remap, &src_vars, vars); NIR_PASS_V(shader, lower_rt_instructions, &src_vars, false); @@ -803,23 +800,6 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni inline_constants(b->shader, shader); - if (vars->monolithic) { - /* Work around vulkancts declaring an incompatible dummy payload. */ - nir_remove_dead_variables(shader, nir_var_shader_call_data, NULL); - - nir_foreach_variable_in_shader (var, shader) { - if (var->data.mode != nir_var_shader_call_data) - continue; - - /* There can only be one shader_call_data variable which has to match the caller payload. */ - if (var->type != vars->arg->type) - return; - - _mesa_hash_table_insert(var_remap, var, vars->arg); - break; - } - } - nir_push_if(b, nir_ieq_imm(b, idx, call_idx)); nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap); nir_pop_if(b, NULL); @@ -827,17 +807,25 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni ralloc_free(var_remap); } -void -radv_nir_lower_rt_io(nir_shader *nir, bool monolithic) +nir_shader * +radv_parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo, + const struct radv_pipeline_key *key, const struct radv_pipeline_layout *pipeline_layout) { - if (!monolithic) { - NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data, - glsl_get_natural_size_align_bytes); + struct radv_shader_stage rt_stage; - NIR_PASS(_, nir, lower_rt_derefs); + radv_pipeline_stage_init(sinfo, pipeline_layout, &rt_stage); - NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset); - } + nir_shader *shader = radv_shader_spirv_to_nir(device, &rt_stage, key, false); + + NIR_PASS(_, shader, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data, + glsl_get_natural_size_align_bytes); + + NIR_PASS(_, shader, lower_rt_derefs); + NIR_PASS(_, shader, radv_nir_lower_hit_attrib_derefs); + + NIR_PASS(_, shader, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset); + + return shader; } static nir_function_impl * @@ -1135,8 +1123,6 @@ radv_build_ahit_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_g radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir); assert(nir_stage); - radv_nir_lower_rt_io(nir_stage, data->vars->monolithic); - insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.any_hit_index); ralloc_free(nir_stage); } @@ -1159,16 +1145,12 @@ radv_build_isec_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_g radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->intersection_shader].nir); assert(nir_stage); - radv_nir_lower_rt_io(nir_stage, data->vars->monolithic); - nir_shader *any_hit_stage = NULL; if (group->any_hit_shader != VK_SHADER_UNUSED_KHR) { any_hit_stage = radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir); assert(any_hit_stage); - radv_nir_lower_rt_io(any_hit_stage, data->vars->monolithic); - /* reserve stack size for any_hit before it is inlined */ data->pipeline->stages[group->any_hit_shader].stack_size = any_hit_stage->scratch_size; @@ -1211,8 +1193,6 @@ radv_build_recursive_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_trac radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->recursive_shader].nir); assert(nir_stage); - radv_nir_lower_rt_io(nir_stage, data->vars->monolithic); - insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.general_index); ralloc_free(nir_stage); } @@ -1534,7 +1514,7 @@ radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_ b.shader->info.workgroup_size[0] = 8; b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4; b.shader->info.shared_size = device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t); - struct rt_variables vars = create_rt_variables(b.shader, device, create_flags, false); + struct rt_variables vars = create_rt_variables(b.shader, device, create_flags); /* initialize trace_ray arguments */ nir_store_var(&b, vars.accel_struct, nir_load_accel_struct_amd(&b), 1); @@ -1587,9 +1567,7 @@ lower_rt_instruction_monolithic(nir_builder *b, nir_instr *instr, void *data) case nir_intrinsic_execute_callable: unreachable("nir_intrinsic_execute_callable"); case nir_intrinsic_trace_ray: { - nir_deref_instr *payload_deref = nir_src_as_deref(intr->src[10]); - assert(payload_deref->deref_type == nir_deref_type_var); - vars->arg = payload_deref->var; + nir_store_var(b, vars->arg, nir_iadd_imm(b, intr->src[10].ssa, -b->shader->scratch_size), 1); nir_src cull_mask = intr->src[2]; bool ignore_cull_mask = nir_src_is_const(cull_mask) && (nir_src_as_uint(cull_mask) & 0xFF) == 0xFF; @@ -1631,20 +1609,6 @@ lower_rt_instruction_monolithic(nir_builder *b, nir_instr *instr, void *data) } } -static bool -radv_rewrite_call_data_deref_modes(nir_builder *b, nir_instr *instr, void *data) -{ - if (instr->type != nir_instr_type_deref) - return false; - - nir_deref_instr *deref = nir_instr_as_deref(instr); - if (!nir_deref_mode_is(deref, nir_var_shader_call_data)) - return false; - - deref->modes = nir_var_shader_temp; - return true; -} - static void lower_rt_instructions_monolithic(nir_shader *shader, struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline, @@ -1662,8 +1626,6 @@ lower_rt_instructions_monolithic(nir_shader *shader, struct radv_device *device, nir_shader_instructions_pass(shader, lower_rt_instruction_monolithic, nir_metadata_none, &state); nir_index_ssa_defs(impl); - nir_shader_instructions_pass(shader, radv_rewrite_call_data_deref_modes, nir_metadata_all, &state); - /* Register storage for hit attributes */ nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_SIZE / sizeof(uint32_t)]; @@ -1720,7 +1682,7 @@ radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKH const VkPipelineCreateFlagBits2KHR create_flags = vk_rt_pipeline_create_flags(pCreateInfo); - struct rt_variables vars = create_rt_variables(shader, device, create_flags, monolithic); + struct rt_variables vars = create_rt_variables(shader, device, create_flags); if (monolithic) lower_rt_instructions_monolithic(shader, device, pipeline, pCreateInfo, &vars); @@ -1751,9 +1713,7 @@ radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKH nir_store_var(&b, vars.stack_ptr, ac_nir_load_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base), 1); nir_def *record_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_record); nir_store_var(&b, vars.shader_record_ptr, nir_pack_64_2x32(&b, record_ptr), 1); - - if (!monolithic) - nir_store_var(&b, vars.arg, ac_nir_load_arg(&b, &args->ac, args->ac.rt.payload_offset), 1); + nir_store_var(&b, vars.arg, ac_nir_load_arg(&b, &args->ac, args->ac.rt.payload_offset), 1); nir_def *accel_struct = ac_nir_load_arg(&b, &args->ac, args->ac.rt.accel_struct); nir_store_var(&b, vars.accel_struct, nir_pack_64_2x32(&b, accel_struct), 1); @@ -1826,7 +1786,4 @@ radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKH NIR_PASS_V(shader, nir_lower_vars_to_ssa); if (shader->info.stage == MESA_SHADER_CLOSEST_HIT || shader->info.stage == MESA_SHADER_INTERSECTION) NIR_PASS_V(shader, lower_hit_attribs, NULL, info->wave_size); - - if (monolithic) - ac_nir_lower_indirect_derefs(shader, device->physical_device->rad_info.gfx_level); } diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c index c4ca2a2126b..04bb1edd493 100644 --- a/src/amd/vulkan/radv_pipeline_rt.c +++ b/src/amd/vulkan/radv_pipeline_rt.c @@ -24,7 +24,6 @@ #include "nir/nir.h" #include "nir/nir_builder.h" -#include "nir/radv_nir.h" #include "radv_debug.h" #include "radv_private.h" #include "radv_shader.h" @@ -366,8 +365,6 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache, bool keep_executable_info = radv_pipeline_capture_shaders(device, pipeline->base.base.create_flags); bool keep_statistic_info = radv_pipeline_capture_shader_stats(device, pipeline->base.base.create_flags); - radv_nir_lower_rt_io(stage->nir, monolithic); - /* Gather shader info. */ nir_shader_gather_info(stage->nir, nir_shader_get_entrypoint(stage->nir)); radv_nir_shader_info_init(stage->stage, MESA_SHADER_NONE, &stage->info); @@ -520,9 +517,7 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca radv_pipeline_stage_init(&pCreateInfo->pStages[i], pipeline_layout, stage); /* precompile the shader */ - stage->nir = radv_shader_spirv_to_nir(device, stage, key, false); - - NIR_PASS(_, stage->nir, radv_nir_lower_hit_attrib_derefs); + stage->nir = radv_parse_rt_stage(device, &pCreateInfo->pStages[i], key, pipeline_layout); rt_stages[i].can_inline = radv_rt_can_inline_shader(stage->nir); diff --git a/src/amd/vulkan/radv_shader.h b/src/amd/vulkan/radv_shader.h index 42c308e39cf..b0d04cc597b 100644 --- a/src/amd/vulkan/radv_shader.h +++ b/src/amd/vulkan/radv_shader.h @@ -726,7 +726,9 @@ void radv_postprocess_nir(struct radv_device *device, const struct radv_pipeline bool radv_shader_should_clear_lds(const struct radv_device *device, const nir_shader *shader); -void radv_nir_lower_rt_io(nir_shader *shader, bool monolithic); +nir_shader *radv_parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo, + const struct radv_pipeline_key *key, + const struct radv_pipeline_layout *pipeline_layout); void radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, const struct radv_shader_args *args, const struct radv_shader_info *info,