diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.cpp b/src/gallium/drivers/d3d12/d3d12_compiler.cpp index a2979435b96..815ae8ee85e 100644 --- a/src/gallium/drivers/d3d12/d3d12_compiler.cpp +++ b/src/gallium/drivers/d3d12/d3d12_compiler.cpp @@ -135,7 +135,7 @@ compile_nir(struct d3d12_context *ctx, struct d3d12_shader_selector *sel, sel->current = shader; NIR_PASS_V(nir, nir_lower_samplers); - NIR_PASS_V(nir, d3d12_create_bare_samplers); + NIR_PASS_V(nir, dxil_nir_create_bare_samplers); if (key->samples_int_textures) NIR_PASS_V(nir, dxil_lower_sample_to_txf_for_integer_tex, diff --git a/src/gallium/drivers/d3d12/d3d12_nir_passes.c b/src/gallium/drivers/d3d12/d3d12_nir_passes.c index 5b28c02d1ff..a1239dc803e 100644 --- a/src/gallium/drivers/d3d12/d3d12_nir_passes.c +++ b/src/gallium/drivers/d3d12/d3d12_nir_passes.c @@ -542,39 +542,6 @@ d3d12_lower_state_vars(nir_shader *nir, struct d3d12_shader *shader) return progress; } -static const struct glsl_type * -get_bare_samplers_for_type(const struct glsl_type *type) -{ - if (glsl_type_is_sampler(type)) { - if (glsl_sampler_type_is_shadow(type)) - return glsl_bare_shadow_sampler_type(); - else - return glsl_bare_sampler_type(); - } else if (glsl_type_is_array(type)) { - return glsl_array_type( - get_bare_samplers_for_type(glsl_get_array_element(type)), - glsl_get_length(type), - 0 /*explicit size*/); - } - assert(!"Unexpected type"); - return NULL; -} - -void -d3d12_create_bare_samplers(nir_shader *nir) -{ - nir_foreach_variable_with_modes_safe(var, nir, nir_var_uniform) { - const struct glsl_type *type = glsl_without_array(var->type); - if (glsl_type_is_sampler(type) && glsl_get_sampler_result_type(type) != GLSL_TYPE_VOID) { - /* Since samplers are already lowered to be accessed by index, all we need to do - * here is create a bare sampler with the same binding */ - nir_variable *clone = nir_variable_clone(var, nir); - clone->type = get_bare_samplers_for_type(var->type); - nir_shader_add_variable(nir, clone); - } - } -} - static bool lower_bool_input_filter(const nir_instr *instr, UNUSED const void *_options) diff --git a/src/gallium/drivers/d3d12/d3d12_nir_passes.h b/src/gallium/drivers/d3d12/d3d12_nir_passes.h index 5f10b29f200..38d36206caf 100644 --- a/src/gallium/drivers/d3d12/d3d12_nir_passes.h +++ b/src/gallium/drivers/d3d12/d3d12_nir_passes.h @@ -54,9 +54,6 @@ d3d12_lower_depth_range(nir_shader *nir); bool d3d12_lower_load_first_vertex(nir_shader *nir); -void -d3d12_create_bare_samplers(nir_shader *s); - bool d3d12_lower_bool_input(struct nir_shader *s); diff --git a/src/microsoft/compiler/dxil_nir.c b/src/microsoft/compiler/dxil_nir.c index aac3c3116de..6659867d802 100644 --- a/src/microsoft/compiler/dxil_nir.c +++ b/src/microsoft/compiler/dxil_nir.c @@ -1355,3 +1355,113 @@ dxil_nir_lower_system_values_to_zero(nir_shader* shader, lower_system_value_to_zero_instr, &state); } + +static const struct glsl_type * +get_bare_samplers_for_type(const struct glsl_type *type) +{ + if (glsl_type_is_sampler(type)) { + if (glsl_sampler_type_is_shadow(type)) + return glsl_bare_shadow_sampler_type(); + else + return glsl_bare_sampler_type(); + } else if (glsl_type_is_array(type)) { + return glsl_array_type( + get_bare_samplers_for_type(glsl_get_array_element(type)), + glsl_get_length(type), + 0 /*explicit size*/); + } + assert(!"Unexpected type"); + return NULL; +} + +static bool +redirect_sampler_derefs(struct nir_builder *b, nir_instr *instr, void *data) +{ + if (instr->type != nir_instr_type_tex) + return false; + + nir_tex_instr *tex = nir_instr_as_tex(instr); + if (!nir_tex_instr_need_sampler(tex)) + return false; + + int sampler_idx = nir_tex_instr_src_index(tex, nir_tex_src_sampler_deref); + if (sampler_idx == -1) { + /* No derefs, must be using indices */ + struct hash_entry *hash_entry = _mesa_hash_table_u64_search(data, tex->sampler_index); + + /* Already have a bare sampler here */ + if (hash_entry) + return false; + + nir_variable *typed_sampler = NULL; + nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) { + if (var->data.binding <= tex->sampler_index && + var->data.binding + glsl_type_get_sampler_count(var->type) > tex->sampler_index) { + /* Already have a bare sampler for this binding, add it to the table */ + if (glsl_get_sampler_result_type(glsl_without_array(var->type)) == GLSL_TYPE_VOID) { + _mesa_hash_table_u64_insert(data, tex->sampler_index, var); + return false; + } + + typed_sampler = var; + } + } + + /* Clone the typed sampler to a bare sampler and we're done */ + assert(typed_sampler); + nir_variable *bare_sampler = nir_variable_clone(typed_sampler, b->shader); + bare_sampler->type = get_bare_samplers_for_type(typed_sampler->type); + nir_shader_add_variable(b->shader, bare_sampler); + _mesa_hash_table_u64_insert(data, tex->sampler_index, bare_sampler); + return true; + } + + /* Using derefs, means we have to rewrite the deref chain in addition to cloning */ + nir_deref_instr *final_deref = nir_src_as_deref(tex->src[sampler_idx].src); + nir_deref_path path; + nir_deref_path_init(&path, final_deref, NULL); + + nir_deref_instr *old_tail = path.path[0]; + assert(old_tail->deref_type == nir_deref_type_var); + nir_variable *old_var = old_tail->var; + if (glsl_get_sampler_result_type(glsl_without_array(old_var->type)) == GLSL_TYPE_VOID) { + nir_deref_path_finish(&path); + return false; + } + + struct hash_entry *hash_entry = _mesa_hash_table_u64_search(data, old_var->data.binding); + nir_variable *new_var; + if (hash_entry) { + new_var = hash_entry->data; + } else { + new_var = nir_variable_clone(old_var, b->shader); + new_var->type = get_bare_samplers_for_type(old_var->type); + nir_shader_add_variable(b->shader, new_var); + _mesa_hash_table_u64_insert(data, old_var->data.binding, new_var); + } + + b->cursor = nir_after_instr(&old_tail->instr); + nir_deref_instr *new_tail = nir_build_deref_var(b, new_var); + + for (unsigned i = 1; path.path[i]; ++i) { + b->cursor = nir_after_instr(&path.path[i]->instr); + new_tail = nir_build_deref_follower(b, new_tail, path.path[i]); + } + + nir_deref_path_finish(&path); + nir_instr_rewrite_src_ssa(&tex->instr, &tex->src[sampler_idx].src, &new_tail->dest.ssa); + + return true; +} + +bool +dxil_nir_create_bare_samplers(nir_shader *nir) +{ + struct hash_table_u64 *sampler_to_bare = _mesa_hash_table_u64_create(NULL); + + bool progress = nir_shader_instructions_pass(nir, redirect_sampler_derefs, + nir_metadata_block_index | nir_metadata_dominance | nir_metadata_loop_analysis, sampler_to_bare); + + _mesa_hash_table_u64_destroy(sampler_to_bare, NULL); + return progress; +} diff --git a/src/microsoft/compiler/dxil_nir.h b/src/microsoft/compiler/dxil_nir.h index 84d4ca32e34..00d35a20a13 100644 --- a/src/microsoft/compiler/dxil_nir.h +++ b/src/microsoft/compiler/dxil_nir.h @@ -49,6 +49,7 @@ bool dxil_nir_lower_double_math(nir_shader *shader); bool dxil_nir_lower_system_values_to_zero(nir_shader *shader, gl_system_value* system_value, uint32_t count); +bool dxil_nir_create_bare_samplers(nir_shader *shader); nir_ssa_def * build_load_ubo_dxil(nir_builder *b, nir_ssa_def *buffer, diff --git a/src/microsoft/spirv_to_dxil/spirv_to_dxil.c b/src/microsoft/spirv_to_dxil/spirv_to_dxil.c index 2eb25da67af..c5462dd1918 100644 --- a/src/microsoft/spirv_to_dxil/spirv_to_dxil.c +++ b/src/microsoft/spirv_to_dxil/spirv_to_dxil.c @@ -127,6 +127,7 @@ spirv_to_dxil(const uint32_t *words, size_t word_count, NIR_PASS_V(nir, nir_lower_readonly_images_to_tex, true); NIR_PASS_V(nir, dxil_nir_split_clip_cull_distance); NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil); + NIR_PASS_V(nir, dxil_nir_create_bare_samplers); nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));