diff --git a/src/microsoft/compiler/dxil_container.c b/src/microsoft/compiler/dxil_container.c index 73b7f38fd04..996affbe4fc 100644 --- a/src/microsoft/compiler/dxil_container.c +++ b/src/microsoft/compiler/dxil_container.c @@ -197,12 +197,14 @@ dxil_container_add_state_validation(struct dxil_container *c, const struct dxil_module *m, struct dxil_validation_state *state) { - uint32_t psv1_size = sizeof(struct dxil_psv_runtime_info_1); + uint32_t psv_size = m->minor_validator >= 6 ? + sizeof(struct dxil_psv_runtime_info_2) : + sizeof(struct dxil_psv_runtime_info_1); uint32_t resource_bind_info_size = 4 * sizeof(uint32_t); uint32_t dxil_pvs_sig_size = sizeof(struct dxil_psv_signature_element); uint32_t resource_count = state->num_resources; - uint32_t size = psv1_size + 2 * sizeof(uint32_t); + uint32_t size = psv_size + 2 * sizeof(uint32_t); if (resource_count > 0) { size += sizeof (uint32_t) + resource_bind_info_size * resource_count; @@ -220,31 +222,31 @@ dxil_container_add_state_validation(struct dxil_container *c, size += dxil_pvs_sig_size * m->num_sig_outputs; size += dxil_pvs_sig_size * m->num_sig_patch_consts; - state->state.sig_input_vectors = (uint8_t)m->num_psv_inputs; + state->state.psv1.sig_input_vectors = (uint8_t)m->num_psv_inputs; for (unsigned i = 0; i < 4; ++i) - state->state.sig_output_vectors[i] = (uint8_t)m->num_psv_outputs[i]; + state->state.psv1.sig_output_vectors[i] = (uint8_t)m->num_psv_outputs[i]; // TODO: Add viewID records size uint32_t dependency_table_size = 0; - if (state->state.sig_input_vectors > 0) { + if (state->state.psv1.sig_input_vectors > 0) { for (unsigned i = 0; i < 4; ++i) { - if (state->state.sig_output_vectors[i] > 0) + if (state->state.psv1.sig_output_vectors[i] > 0) dependency_table_size += sizeof(uint32_t) * - compute_input_output_table_dwords(state->state.sig_input_vectors, - state->state.sig_output_vectors[i]); + compute_input_output_table_dwords(state->state.psv1.sig_input_vectors, + state->state.psv1.sig_output_vectors[i]); } - if (state->state.shader_stage == DXIL_HULL_SHADER && state->state.sig_patch_const_or_prim_vectors) { - dependency_table_size += sizeof(uint32_t) * compute_input_output_table_dwords(state->state.sig_input_vectors, - state->state.sig_patch_const_or_prim_vectors); + if (state->state.psv1.shader_stage == DXIL_HULL_SHADER && state->state.psv1.sig_patch_const_or_prim_vectors) { + dependency_table_size += sizeof(uint32_t) * compute_input_output_table_dwords(state->state.psv1.sig_input_vectors, + state->state.psv1.sig_patch_const_or_prim_vectors); } } - if (state->state.shader_stage == DXIL_DOMAIN_SHADER && - state->state.sig_patch_const_or_prim_vectors && - state->state.sig_output_vectors[0]) { + if (state->state.psv1.shader_stage == DXIL_DOMAIN_SHADER && + state->state.psv1.sig_patch_const_or_prim_vectors && + state->state.psv1.sig_output_vectors[0]) { dependency_table_size += sizeof(uint32_t) * compute_input_output_table_dwords( - state->state.sig_patch_const_or_prim_vectors, state->state.sig_output_vectors[0]); + state->state.psv1.sig_patch_const_or_prim_vectors, state->state.psv1.sig_output_vectors[0]); } size += dependency_table_size; // TODO: Domain shader table goes here @@ -252,10 +254,10 @@ dxil_container_add_state_validation(struct dxil_container *c, if (!add_part_header(c, DXIL_PSV0, size)) return false; - if (!blob_write_bytes(&c->parts, &psv1_size, sizeof(psv1_size))) + if (!blob_write_bytes(&c->parts, &psv_size, sizeof(psv_size))) return false; - if (!blob_write_bytes(&c->parts, &state->state, psv1_size)) + if (!blob_write_bytes(&c->parts, &state->state, psv_size)) return false; if (!blob_write_bytes(&c->parts, &resource_count, sizeof(resource_count))) diff --git a/src/microsoft/compiler/dxil_container.h b/src/microsoft/compiler/dxil_container.h index 23787b56e73..29eb1160533 100644 --- a/src/microsoft/compiler/dxil_container.h +++ b/src/microsoft/compiler/dxil_container.h @@ -82,7 +82,7 @@ struct dxil_resource { }; struct dxil_validation_state { - struct dxil_psv_runtime_info_1 state; + struct dxil_psv_runtime_info_2 state; const struct dxil_resource *resources; uint32_t num_resources; }; diff --git a/src/microsoft/compiler/dxil_signature.h b/src/microsoft/compiler/dxil_signature.h index 567af9127fa..0ed708d9517 100644 --- a/src/microsoft/compiler/dxil_signature.h +++ b/src/microsoft/compiler/dxil_signature.h @@ -141,6 +141,13 @@ struct dxil_psv_runtime_info_1 { uint8_t sig_output_vectors[4]; }; +struct dxil_psv_runtime_info_2 { + struct dxil_psv_runtime_info_1 psv1; + uint32_t num_threads_x; + uint32_t num_threads_y; + uint32_t num_threads_z; +}; + struct dxil_mdnode; struct dxil_module; diff --git a/src/microsoft/compiler/nir_to_dxil.c b/src/microsoft/compiler/nir_to_dxil.c index aeb51d2a91d..15f98e94f17 100644 --- a/src/microsoft/compiler/nir_to_dxil.c +++ b/src/microsoft/compiler/nir_to_dxil.c @@ -5609,43 +5609,46 @@ void dxil_fill_validation_state(struct ntd_context *ctx, { state->num_resources = util_dynarray_num_elements(&ctx->resources, struct dxil_resource); state->resources = (struct dxil_resource*)ctx->resources.data; - state->state.psv0.max_expected_wave_lane_count = UINT_MAX; - state->state.shader_stage = (uint8_t)ctx->mod.shader_kind; - state->state.sig_input_elements = (uint8_t)ctx->mod.num_sig_inputs; - state->state.sig_output_elements = (uint8_t)ctx->mod.num_sig_outputs; - state->state.sig_patch_const_or_prim_elements = (uint8_t)ctx->mod.num_sig_patch_consts; + state->state.psv1.psv0.max_expected_wave_lane_count = UINT_MAX; + state->state.psv1.shader_stage = (uint8_t)ctx->mod.shader_kind; + state->state.psv1.sig_input_elements = (uint8_t)ctx->mod.num_sig_inputs; + state->state.psv1.sig_output_elements = (uint8_t)ctx->mod.num_sig_outputs; + state->state.psv1.sig_patch_const_or_prim_elements = (uint8_t)ctx->mod.num_sig_patch_consts; switch (ctx->mod.shader_kind) { case DXIL_VERTEX_SHADER: - state->state.psv0.vs.output_position_present = ctx->mod.info.has_out_position; + state->state.psv1.psv0.vs.output_position_present = ctx->mod.info.has_out_position; break; case DXIL_PIXEL_SHADER: /* TODO: handle depth outputs */ - state->state.psv0.ps.depth_output = ctx->mod.info.has_out_depth; - state->state.psv0.ps.sample_frequency = + state->state.psv1.psv0.ps.depth_output = ctx->mod.info.has_out_depth; + state->state.psv1.psv0.ps.sample_frequency = ctx->mod.info.has_per_sample_input; break; case DXIL_COMPUTE_SHADER: + state->state.num_threads_x = ctx->shader->info.workgroup_size[0]; + state->state.num_threads_y = ctx->shader->info.workgroup_size[1]; + state->state.num_threads_z = ctx->shader->info.workgroup_size[2]; break; case DXIL_GEOMETRY_SHADER: - state->state.max_vertex_count = ctx->shader->info.gs.vertices_out; - state->state.psv0.gs.input_primitive = dxil_get_input_primitive(ctx->shader->info.gs.input_primitive); - state->state.psv0.gs.output_toplology = dxil_get_primitive_topology(ctx->shader->info.gs.output_primitive); - state->state.psv0.gs.output_stream_mask = MAX2(ctx->shader->info.gs.active_stream_mask, 1); - state->state.psv0.gs.output_position_present = ctx->mod.info.has_out_position; + state->state.psv1.max_vertex_count = ctx->shader->info.gs.vertices_out; + state->state.psv1.psv0.gs.input_primitive = dxil_get_input_primitive(ctx->shader->info.gs.input_primitive); + state->state.psv1.psv0.gs.output_toplology = dxil_get_primitive_topology(ctx->shader->info.gs.output_primitive); + state->state.psv1.psv0.gs.output_stream_mask = MAX2(ctx->shader->info.gs.active_stream_mask, 1); + state->state.psv1.psv0.gs.output_position_present = ctx->mod.info.has_out_position; break; case DXIL_HULL_SHADER: - state->state.psv0.hs.input_control_point_count = ctx->tess_input_control_point_count; - state->state.psv0.hs.output_control_point_count = ctx->shader->info.tess.tcs_vertices_out; - state->state.psv0.hs.tessellator_domain = get_tessellator_domain(ctx->shader->info.tess._primitive_mode); - state->state.psv0.hs.tessellator_output_primitive = get_tessellator_output_primitive(&ctx->shader->info); - state->state.sig_patch_const_or_prim_vectors = ctx->mod.num_psv_patch_consts; + state->state.psv1.psv0.hs.input_control_point_count = ctx->tess_input_control_point_count; + state->state.psv1.psv0.hs.output_control_point_count = ctx->shader->info.tess.tcs_vertices_out; + state->state.psv1.psv0.hs.tessellator_domain = get_tessellator_domain(ctx->shader->info.tess._primitive_mode); + state->state.psv1.psv0.hs.tessellator_output_primitive = get_tessellator_output_primitive(&ctx->shader->info); + state->state.psv1.sig_patch_const_or_prim_vectors = ctx->mod.num_psv_patch_consts; break; case DXIL_DOMAIN_SHADER: - state->state.psv0.ds.input_control_point_count = ctx->shader->info.tess.tcs_vertices_out; - state->state.psv0.ds.tessellator_domain = get_tessellator_domain(ctx->shader->info.tess._primitive_mode); - state->state.psv0.ds.output_position_present = ctx->mod.info.has_out_position; - state->state.sig_patch_const_or_prim_vectors = ctx->mod.num_psv_patch_consts; + state->state.psv1.psv0.ds.input_control_point_count = ctx->shader->info.tess.tcs_vertices_out; + state->state.psv1.psv0.ds.tessellator_domain = get_tessellator_domain(ctx->shader->info.tess._primitive_mode); + state->state.psv1.psv0.ds.output_position_present = ctx->mod.info.has_out_position; + state->state.psv1.sig_patch_const_or_prim_vectors = ctx->mod.num_psv_patch_consts; break; default: assert(0 && "Shader type not (yet) supported");