diff --git a/src/amd/vulkan/radv_shader_object.c b/src/amd/vulkan/radv_shader_object.c index 475de93ce71..bfa0a363913 100644 --- a/src/amd/vulkan/radv_shader_object.c +++ b/src/amd/vulkan/radv_shader_object.c @@ -83,6 +83,22 @@ radv_shader_stage_init(const VkShaderCreateInfoEXT *sinfo, struct radv_shader_st } out_stage->layout.push_constant_size = align(out_stage->layout.push_constant_size, 16); + + const VkShaderRequiredSubgroupSizeCreateInfoEXT *const subgroup_size = + vk_find_struct_const(sinfo->pNext, SHADER_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT); + + if (subgroup_size) { + if (subgroup_size->requiredSubgroupSize == 32) + out_stage->key.subgroup_required_size = RADV_REQUIRED_WAVE32; + else if (subgroup_size->requiredSubgroupSize == 64) + out_stage->key.subgroup_required_size = RADV_REQUIRED_WAVE64; + else + unreachable("Unsupported required subgroup size."); + } + + if (sinfo->flags & VK_SHADER_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT) { + out_stage->key.subgroup_require_full = 1; + } } static VkResult @@ -114,22 +130,6 @@ radv_shader_object_init_graphics(struct radv_shader_object *shader_obj, struct r if (device->physical_device->rad_info.gfx_level >= GFX11) gfx_state.ms.alpha_to_coverage_via_mrtz = true; - const VkShaderRequiredSubgroupSizeCreateInfoEXT *const subgroup_size = - vk_find_struct_const(pCreateInfo->pNext, SHADER_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT); - - if (subgroup_size) { - if (subgroup_size->requiredSubgroupSize == 32) - stages[stage].key.subgroup_required_size = RADV_REQUIRED_WAVE32; - else if (subgroup_size->requiredSubgroupSize == 64) - stages[stage].key.subgroup_required_size = RADV_REQUIRED_WAVE64; - else - unreachable("Unsupported required subgroup size."); - } - - if (pCreateInfo->flags & VK_SHADER_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT) { - stages[stage].key.subgroup_require_full = 1; - } - struct radv_shader *shader = NULL; struct radv_shader_binary *binary = NULL; @@ -204,22 +204,6 @@ radv_shader_object_init_compute(struct radv_shader_object *shader_obj, struct ra radv_shader_stage_init(pCreateInfo, &stage); - const VkShaderRequiredSubgroupSizeCreateInfoEXT *const subgroup_size = - vk_find_struct_const(pCreateInfo->pNext, SHADER_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT); - - if (subgroup_size) { - if (subgroup_size->requiredSubgroupSize == 32) - stage.key.subgroup_required_size = RADV_REQUIRED_WAVE32; - else if (subgroup_size->requiredSubgroupSize == 64) - stage.key.subgroup_required_size = RADV_REQUIRED_WAVE64; - else - unreachable("Unsupported required subgroup size."); - } - - if (pCreateInfo->flags & VK_SHADER_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT) { - stage.key.subgroup_require_full = 1; - } - struct radv_shader *cs_shader = radv_compile_cs(device, NULL, &stage, true, false, false, &cs_binary); ralloc_free(stage.nir); @@ -427,22 +411,6 @@ radv_shader_object_create_linked(VkDevice _device, uint32_t createInfoCount, con gl_shader_stage s = vk_to_mesa_shader_stage(pCreateInfo->stage); radv_shader_stage_init(pCreateInfo, &stages[s]); - - const VkShaderRequiredSubgroupSizeCreateInfoEXT *const subgroup_size = - vk_find_struct_const(pCreateInfo->pNext, SHADER_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT); - - if (subgroup_size) { - if (subgroup_size->requiredSubgroupSize == 32) - stages[s].key.subgroup_required_size = RADV_REQUIRED_WAVE32; - else if (subgroup_size->requiredSubgroupSize == 64) - stages[s].key.subgroup_required_size = RADV_REQUIRED_WAVE64; - else - unreachable("Unsupported required subgroup size."); - } - - if (pCreateInfo->flags & VK_SHADER_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT) { - stages[s].key.subgroup_require_full = 1; - } } /* Determine next stage. */