radv: init the shader key in radv_shader_stage_init() for ESO

Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/27237>
This commit is contained in:
Samuel Pitoiset
2024-01-22 09:37:11 +01:00
committed by Marge Bot
parent 13add95beb
commit 9211eef738

View File

@@ -83,6 +83,22 @@ radv_shader_stage_init(const VkShaderCreateInfoEXT *sinfo, struct radv_shader_st
}
out_stage->layout.push_constant_size = align(out_stage->layout.push_constant_size, 16);
const VkShaderRequiredSubgroupSizeCreateInfoEXT *const subgroup_size =
vk_find_struct_const(sinfo->pNext, SHADER_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT);
if (subgroup_size) {
if (subgroup_size->requiredSubgroupSize == 32)
out_stage->key.subgroup_required_size = RADV_REQUIRED_WAVE32;
else if (subgroup_size->requiredSubgroupSize == 64)
out_stage->key.subgroup_required_size = RADV_REQUIRED_WAVE64;
else
unreachable("Unsupported required subgroup size.");
}
if (sinfo->flags & VK_SHADER_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT) {
out_stage->key.subgroup_require_full = 1;
}
}
static VkResult
@@ -114,22 +130,6 @@ radv_shader_object_init_graphics(struct radv_shader_object *shader_obj, struct r
if (device->physical_device->rad_info.gfx_level >= GFX11)
gfx_state.ms.alpha_to_coverage_via_mrtz = true;
const VkShaderRequiredSubgroupSizeCreateInfoEXT *const subgroup_size =
vk_find_struct_const(pCreateInfo->pNext, SHADER_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT);
if (subgroup_size) {
if (subgroup_size->requiredSubgroupSize == 32)
stages[stage].key.subgroup_required_size = RADV_REQUIRED_WAVE32;
else if (subgroup_size->requiredSubgroupSize == 64)
stages[stage].key.subgroup_required_size = RADV_REQUIRED_WAVE64;
else
unreachable("Unsupported required subgroup size.");
}
if (pCreateInfo->flags & VK_SHADER_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT) {
stages[stage].key.subgroup_require_full = 1;
}
struct radv_shader *shader = NULL;
struct radv_shader_binary *binary = NULL;
@@ -204,22 +204,6 @@ radv_shader_object_init_compute(struct radv_shader_object *shader_obj, struct ra
radv_shader_stage_init(pCreateInfo, &stage);
const VkShaderRequiredSubgroupSizeCreateInfoEXT *const subgroup_size =
vk_find_struct_const(pCreateInfo->pNext, SHADER_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT);
if (subgroup_size) {
if (subgroup_size->requiredSubgroupSize == 32)
stage.key.subgroup_required_size = RADV_REQUIRED_WAVE32;
else if (subgroup_size->requiredSubgroupSize == 64)
stage.key.subgroup_required_size = RADV_REQUIRED_WAVE64;
else
unreachable("Unsupported required subgroup size.");
}
if (pCreateInfo->flags & VK_SHADER_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT) {
stage.key.subgroup_require_full = 1;
}
struct radv_shader *cs_shader = radv_compile_cs(device, NULL, &stage, true, false, false, &cs_binary);
ralloc_free(stage.nir);
@@ -427,22 +411,6 @@ radv_shader_object_create_linked(VkDevice _device, uint32_t createInfoCount, con
gl_shader_stage s = vk_to_mesa_shader_stage(pCreateInfo->stage);
radv_shader_stage_init(pCreateInfo, &stages[s]);
const VkShaderRequiredSubgroupSizeCreateInfoEXT *const subgroup_size =
vk_find_struct_const(pCreateInfo->pNext, SHADER_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT);
if (subgroup_size) {
if (subgroup_size->requiredSubgroupSize == 32)
stages[s].key.subgroup_required_size = RADV_REQUIRED_WAVE32;
else if (subgroup_size->requiredSubgroupSize == 64)
stages[s].key.subgroup_required_size = RADV_REQUIRED_WAVE64;
else
unreachable("Unsupported required subgroup size.");
}
if (pCreateInfo->flags & VK_SHADER_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT) {
stages[s].key.subgroup_require_full = 1;
}
}
/* Determine next stage. */