From e21774eb0aaefd150828fc8c21ff9f6c9ce380df Mon Sep 17 00:00:00 2001
From: Jesse Natalie
Date: Wed, 14 Feb 2024 14:29:17 -0800
Subject: [PATCH] microsoft/compiler: Fix wave size control for SM6.6+

Part-of:
---
Reviewer notes (this section sits below the "---" marker and is not part of
the commit message):
 - Adds emit_wave_size(), which builds an int32 metadata value from
   s->info.subgroup_size and wraps it in a one-element metadata node.
   emit_metadata() now passes that node for DXIL_SHADER_TAG_WAVE_SIZE instead
   of the bare int32 value.
 - dxil_fill_validation_state() now sets both min_expected_wave_lane_count and
   max_expected_wave_lane_count to subgroup_size when subgroup_size is
   >= SUBGROUP_SIZE_REQUIRE_8; otherwise only max_expected_wave_lane_count is
   written (UINT_MAX), as before.
 - NOTE(review): emit_wave_size() hands the result of
   dxil_get_metadata_int32() straight to dxil_get_metadata_node() without
   checking it -- confirm whether a failed lookup/allocation needs handling.

 src/microsoft/compiler/nir_to_dxil.c | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/src/microsoft/compiler/nir_to_dxil.c b/src/microsoft/compiler/nir_to_dxil.c
index 583e23b4948..8d8d0fda87e 100644
--- a/src/microsoft/compiler/nir_to_dxil.c
+++ b/src/microsoft/compiler/nir_to_dxil.c
@@ -1814,6 +1814,14 @@ emit_threads(struct ntd_context *ctx)
    return dxil_get_metadata_node(&ctx->mod, threads_nodes, ARRAY_SIZE(threads_nodes));
 }
 
+static const struct dxil_mdnode *
+emit_wave_size(struct ntd_context *ctx)
+{
+   const nir_shader *s = ctx->shader;
+   const struct dxil_mdnode *wave_size_node = dxil_get_metadata_int32(&ctx->mod, s->info.subgroup_size);
+   return dxil_get_metadata_node(&ctx->mod, &wave_size_node, 1);
+}
+
 static int64_t
 get_module_flags(struct ntd_context *ctx)
 {
@@ -2040,8 +2048,7 @@ emit_metadata(struct ntd_context *ctx)
             return false;
          if (ctx->mod.minor_version >= 6 &&
              ctx->shader->info.subgroup_size >= SUBGROUP_SIZE_REQUIRE_8 &&
-             !emit_tag(ctx, DXIL_SHADER_TAG_WAVE_SIZE,
-                       dxil_get_metadata_int32(&ctx->mod, ctx->shader->info.subgroup_size)))
+             !emit_tag(ctx, DXIL_SHADER_TAG_WAVE_SIZE, emit_wave_size(ctx)))
             return false;
       }
 
@@ -6324,7 +6331,12 @@ void dxil_fill_validation_state(struct ntd_context *ctx,
       sizeof(struct dxil_resource_v1) : sizeof(struct dxil_resource_v0);
    state->num_resources = ctx->resources.size / resource_element_size;
    state->resources.v0 = (struct dxil_resource_v0*)ctx->resources.data;
-   state->state.psv1.psv0.max_expected_wave_lane_count = UINT_MAX;
+   if (ctx->shader->info.subgroup_size >= SUBGROUP_SIZE_REQUIRE_8) {
+      state->state.psv1.psv0.max_expected_wave_lane_count = ctx->shader->info.subgroup_size;
+      state->state.psv1.psv0.min_expected_wave_lane_count = ctx->shader->info.subgroup_size;
+   } else {
+      state->state.psv1.psv0.max_expected_wave_lane_count = UINT_MAX;
+   }
    state->state.psv1.shader_stage = (uint8_t)ctx->mod.shader_kind;
    state->state.psv1.uses_view_id = (uint8_t)ctx->mod.feats.view_id;
    state->state.psv1.sig_input_elements = (uint8_t)ctx->mod.num_sig_inputs;