lavapipe: fix mesh+task binding with shader objects

if mesh and task shaders are bound separately, and if they have different
workgroup sizes, the setting of workgroup size will be broken if
set during shader bind

this must be deferred to draw time to pull the correct values

cc: mesa-stable

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29733>
This commit is contained in:
Mike Blumenkrantz
2024-06-14 10:37:38 -04:00
committed by Marge Bot
parent 7bea6f8612
commit 2bb35bf489

View File

@@ -730,19 +730,11 @@ handle_graphics_stages(struct rendering_state *state, VkShaderStageFlagBits shad
break;
case VK_SHADER_STAGE_TASK_BIT_EXT:
state->inlines_dirty[MESA_SHADER_TASK] = state->shaders[MESA_SHADER_TASK]->inlines.can_inline;
state->dispatch_info.block[0] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[0];
state->dispatch_info.block[1] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[1];
state->dispatch_info.block[2] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[2];
if (!state->shaders[MESA_SHADER_TASK]->inlines.can_inline)
state->pctx->bind_ts_state(state->pctx, state->shaders[MESA_SHADER_TASK]->shader_cso);
break;
case VK_SHADER_STAGE_MESH_BIT_EXT:
state->inlines_dirty[MESA_SHADER_MESH] = state->shaders[MESA_SHADER_MESH]->inlines.can_inline;
if (!(shader_stages & VK_SHADER_STAGE_TASK_BIT_EXT)) {
state->dispatch_info.block[0] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[0];
state->dispatch_info.block[1] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[1];
state->dispatch_info.block[2] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[2];
}
if (!state->shaders[MESA_SHADER_MESH]->inlines.can_inline)
state->pctx->bind_ms_state(state->pctx, state->shaders[MESA_SHADER_MESH]->shader_cso);
break;
@@ -3863,9 +3855,24 @@ handle_shaders(struct vk_cmd_queue_entry *cmd, struct rendering_state *state)
}
}
static void
update_mesh_state(struct rendering_state *state)
{
if (state->shaders[MESA_SHADER_TASK]) {
state->dispatch_info.block[0] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[0];
state->dispatch_info.block[1] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[1];
state->dispatch_info.block[2] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[2];
} else {
state->dispatch_info.block[0] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[0];
state->dispatch_info.block[1] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[1];
state->dispatch_info.block[2] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[2];
}
}
static void handle_draw_mesh_tasks(struct vk_cmd_queue_entry *cmd,
struct rendering_state *state)
{
update_mesh_state(state);
state->dispatch_info.grid[0] = cmd->u.draw_mesh_tasks_ext.group_count_x;
state->dispatch_info.grid[1] = cmd->u.draw_mesh_tasks_ext.group_count_y;
state->dispatch_info.grid[2] = cmd->u.draw_mesh_tasks_ext.group_count_z;
@@ -3880,6 +3887,7 @@ static void handle_draw_mesh_tasks(struct vk_cmd_queue_entry *cmd,
static void handle_draw_mesh_tasks_indirect(struct vk_cmd_queue_entry *cmd,
struct rendering_state *state)
{
update_mesh_state(state);
state->dispatch_info.indirect = lvp_buffer_from_handle(cmd->u.draw_mesh_tasks_indirect_ext.buffer)->bo;
state->dispatch_info.indirect_offset = cmd->u.draw_mesh_tasks_indirect_ext.offset;
state->dispatch_info.indirect_stride = cmd->u.draw_mesh_tasks_indirect_ext.stride;
@@ -3890,6 +3898,7 @@ static void handle_draw_mesh_tasks_indirect(struct vk_cmd_queue_entry *cmd,
static void handle_draw_mesh_tasks_indirect_count(struct vk_cmd_queue_entry *cmd,
struct rendering_state *state)
{
update_mesh_state(state);
state->dispatch_info.indirect = lvp_buffer_from_handle(cmd->u.draw_mesh_tasks_indirect_count_ext.buffer)->bo;
state->dispatch_info.indirect_offset = cmd->u.draw_mesh_tasks_indirect_count_ext.offset;
state->dispatch_info.indirect_stride = cmd->u.draw_mesh_tasks_indirect_count_ext.stride;