diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.cpp b/src/gallium/drivers/d3d12/d3d12_compiler.cpp index 90849551edd..ec43465558c 100644 --- a/src/gallium/drivers/d3d12/d3d12_compiler.cpp +++ b/src/gallium/drivers/d3d12/d3d12_compiler.cpp @@ -633,6 +633,9 @@ d3d12_compare_shader_keys(const d3d12_shader_key *expect, const d3d12_shader_key if (expect->n_texture_states != have->n_texture_states) return false; + if (expect->n_images != have->n_images) + return false; + if (memcmp(expect->tex_wrap_states, have->tex_wrap_states, expect->n_texture_states * sizeof(dxil_wrap_sampler_state))) return false; @@ -645,6 +648,10 @@ d3d12_compare_shader_keys(const d3d12_shader_key *expect, const d3d12_shader_key expect->n_texture_states * sizeof(enum compare_func))) return false; + if (memcmp(expect->image_format_conversion, have->image_format_conversion, + expect->n_images * sizeof(struct d3d12_image_format_conversion_info))) + return false; + if (expect->invert_depth != have->invert_depth) return false; @@ -804,6 +811,13 @@ d3d12_fill_shader_key(struct d3d12_selection_context *sel_ctx, sel_ctx->ctx->gfx_stages[PIPE_SHADER_GEOMETRY]->gs_key.has_front_face) { key->fs.remap_front_facing = 1; } + + key->n_images = sel_ctx->ctx->num_image_views[stage]; + for (int i = 0; i < key->n_images; ++i) { + key->image_format_conversion[i].emulated_format = sel_ctx->ctx->image_view_emulation_formats[stage][i]; + if (key->image_format_conversion[i].emulated_format != PIPE_FORMAT_NONE) + key->image_format_conversion[i].view_format = sel_ctx->ctx->image_views[stage][i].format; + } } static void @@ -887,6 +901,9 @@ select_shader_variant(struct d3d12_selection_context *sel_ctx, d3d12_shader_sele if (key.fs.cast_to_int) NIR_PASS_V(new_nir_variant, d3d12_lower_uint_cast, true); + if (key.n_images) + NIR_PASS_V(new_nir_variant, d3d12_lower_image_casts, key.image_format_conversion); + { struct nir_lower_tex_options tex_options = { }; tex_options.lower_txp = ~0u; /* No equivalent for textureProj */ diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.h b/src/gallium/drivers/d3d12/d3d12_compiler.h index 4947102ce95..6398aef4169 100644 --- a/src/gallium/drivers/d3d12/d3d12_compiler.h +++ b/src/gallium/drivers/d3d12/d3d12_compiler.h @@ -68,6 +68,10 @@ struct d3d12_varying_info { uint64_t mask; }; +struct d3d12_image_format_conversion_info { + enum pipe_format view_format, emulated_format; +}; + struct d3d12_shader_key { enum pipe_shader_type stage; @@ -114,6 +118,9 @@ struct d3d12_shader_key { dxil_wrap_sampler_state tex_wrap_states[PIPE_MAX_SHADER_SAMPLER_VIEWS]; dxil_texture_swizzle_state swizzle_state[PIPE_MAX_SHADER_SAMPLER_VIEWS]; enum compare_func sampler_compare_funcs[PIPE_MAX_SHADER_SAMPLER_VIEWS]; + + int n_images; + struct d3d12_image_format_conversion_info image_format_conversion[PIPE_MAX_SHADER_IMAGES]; }; struct d3d12_shader { diff --git a/src/gallium/drivers/d3d12/d3d12_context.cpp b/src/gallium/drivers/d3d12/d3d12_context.cpp index b1e6065c3cb..9fe45130a46 100644 --- a/src/gallium/drivers/d3d12/d3d12_context.cpp +++ b/src/gallium/drivers/d3d12/d3d12_context.cpp @@ -1536,6 +1536,47 @@ d3d12_increment_image_bind_count(struct d3d12_context *ctx, res->bind_counts[shader][D3D12_RESOURCE_BINDING_TYPE_IMAGE]++; } +static bool +is_valid_uav_cast(enum pipe_format resource_format, enum pipe_format view_format) +{ + if (view_format != PIPE_FORMAT_R32_UINT && + view_format != PIPE_FORMAT_R32_SINT && + view_format != PIPE_FORMAT_R32_FLOAT) + return false; + switch (d3d12_get_typeless_format(resource_format)) { + case DXGI_FORMAT_R8G8B8A8_TYPELESS: + case DXGI_FORMAT_B8G8R8A8_TYPELESS: + case DXGI_FORMAT_B8G8R8X8_TYPELESS: + case DXGI_FORMAT_R16G16_TYPELESS: + case DXGI_FORMAT_R10G10B10A2_TYPELESS: + return true; + default: + return false; + } +} + +static enum pipe_format +get_shader_image_emulation_format(enum pipe_format resource_format) +{ +#define CASE(f) case DXGI_FORMAT_##f##_TYPELESS: return PIPE_FORMAT_##f##_UINT + switch (d3d12_get_typeless_format(resource_format)) { + CASE(R8); + CASE(R8G8); + CASE(R8G8B8A8); + CASE(R16); + CASE(R16G16); + CASE(R16G16B16A16); + CASE(R32); + CASE(R32G32); + CASE(R32G32B32A32); + CASE(R10G10B10A2); + case DXGI_FORMAT_R11G11B10_FLOAT: + return PIPE_FORMAT_R11G11B10_FLOAT; + default: + unreachable("Unexpected shader image resource format"); + } +} + static void d3d12_set_shader_images(struct pipe_context *pctx, enum pipe_shader_type shader, @@ -1551,12 +1592,27 @@ d3d12_set_shader_images(struct pipe_context *pctx, pipe_resource_reference(&slot->resource, NULL); } + enum pipe_format emulation_format = PIPE_FORMAT_NONE; if (i < count && images && images[i].resource) { pipe_resource_reference(&slot->resource, images[i].resource); *slot = images[i]; d3d12_increment_image_bind_count(ctx, shader, d3d12_resource(images[i].resource)); + + if (images[i].resource->target != PIPE_BUFFER && + !is_valid_uav_cast(images[i].resource->format, images[i].format) && + d3d12_get_typeless_format(images[i].format) != + d3d12_get_typeless_format(images[i].resource->format)) { + /* Can't use D3D casting, have to use shader lowering instead */ + emulation_format = + get_shader_image_emulation_format(images[i].resource->format); + } } else memset(slot, 0, sizeof(*slot)); + + if (ctx->image_view_emulation_formats[shader][i] != emulation_format) { + ctx->image_view_emulation_formats[shader][i] = emulation_format; + ctx->state_dirty |= D3D12_DIRTY_SHADER; + } } if (images) { diff --git a/src/gallium/drivers/d3d12/d3d12_context.h b/src/gallium/drivers/d3d12/d3d12_context.h index fb34ae4af3c..36a02e00cc5 100644 --- a/src/gallium/drivers/d3d12/d3d12_context.h +++ b/src/gallium/drivers/d3d12/d3d12_context.h @@ -189,6 +189,7 @@ struct d3d12_context { struct pipe_shader_buffer ssbo_views[PIPE_SHADER_TYPES][PIPE_MAX_SHADER_BUFFERS]; unsigned num_ssbo_views[PIPE_SHADER_TYPES]; struct pipe_image_view image_views[PIPE_SHADER_TYPES][PIPE_MAX_SHADER_IMAGES]; + enum pipe_format image_view_emulation_formats[PIPE_SHADER_TYPES][PIPE_MAX_SHADER_IMAGES]; unsigned num_image_views[PIPE_SHADER_TYPES]; struct d3d12_sampler_state *samplers[PIPE_SHADER_TYPES][PIPE_MAX_SAMPLERS]; unsigned num_samplers[PIPE_SHADER_TYPES]; diff --git a/src/gallium/drivers/d3d12/d3d12_draw.cpp b/src/gallium/drivers/d3d12/d3d12_draw.cpp index 53194c0fb34..e12e08d445f 100644 --- a/src/gallium/drivers/d3d12/d3d12_draw.cpp +++ b/src/gallium/drivers/d3d12/d3d12_draw.cpp @@ -254,7 +254,10 @@ fill_image_descriptors(struct d3d12_context *ctx, uint64_t offset = 0; ID3D12Resource *d3d12_res = d3d12_resource_underlying(res, &offset); - uav_desc.Format = d3d12_get_format(res->base.b.format); + enum pipe_format view_format = ctx->image_view_emulation_formats[stage][i]; + if (view_format == PIPE_FORMAT_NONE) + view_format = view->format; + uav_desc.Format = d3d12_get_format(view_format); uav_desc.ViewDimension = image_view_dimension(res->base.b.target); unsigned array_size = view->u.tex.last_layer - view->u.tex.first_layer + 1; diff --git a/src/gallium/drivers/d3d12/d3d12_lower_image_casts.c b/src/gallium/drivers/d3d12/d3d12_lower_image_casts.c new file mode 100644 index 00000000000..b1a9d629727 --- /dev/null +++ b/src/gallium/drivers/d3d12/d3d12_lower_image_casts.c @@ -0,0 +1,261 @@ +/* + * Copyright © Microsoft Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice (including the next + * paragraph) shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#include "nir.h" +#include "nir_builder.h" +#include "nir_format_convert.h" + +#include "pipe/p_state.h" +#include "util/format/u_format.h" + +#include "d3d12_compiler.h" +#include "d3d12_nir_passes.h" + +static nir_ssa_def * +convert_value(nir_builder *b, nir_ssa_def *value, + const struct util_format_description *from_desc, + const struct util_format_description *to_desc) +{ + if (from_desc->format == to_desc->format) + return value; + + assert(value->num_components == 4); + /* No support for 16 or 64 bit data in the shader for image loads/stores */ + assert(value->bit_size == 32); + /* Overall format size needs to be the same */ + assert(from_desc->block.bits == to_desc->block.bits); + assert(from_desc->nr_channels <= 4 && to_desc->nr_channels <= 4); + + const unsigned rgba1010102_bits[] = { 10, 10, 10, 2 }; + + /* First, construct a "tightly packed" vector of the input values. For unorm/snorm, convert + * from the float we're given into the original bits (only happens while storing). For packed + * formats that don't fall on a nice bit size, convert/pack them into 32bit values. Otherwise, + * just produce a vecNx4 where N is the expected bit size. + */ + nir_ssa_def *src_as_vec; + if (from_desc->format == PIPE_FORMAT_R10G10B10A2_UINT || + from_desc->format == PIPE_FORMAT_R10G10B10A2_UNORM) { + if (from_desc->format == PIPE_FORMAT_R10G10B10A2_UNORM) + value = nir_format_float_to_unorm(b, value, rgba1010102_bits); + nir_ssa_def *channels[4]; + for (unsigned i = 0; i < 4; ++i) + channels[i] = nir_channel(b, value, i); + + src_as_vec = channels[0]; + src_as_vec = nir_mask_shift_or(b, src_as_vec, channels[1], (1 << 10) - 1, 10); + src_as_vec = nir_mask_shift_or(b, src_as_vec, channels[2], (1 << 10) - 1, 20); + src_as_vec = nir_mask_shift_or(b, src_as_vec, channels[3], (1 << 2) - 1, 30); + } else if (from_desc->format == PIPE_FORMAT_R11G11B10_FLOAT) { + src_as_vec = nir_format_pack_11f11f10f(b, value); + } else if (from_desc->is_unorm) { + if (from_desc->channel[0].size == 8) + src_as_vec = nir_pack_unorm_4x8(b, value); + else { + nir_ssa_def *packed_channels[2]; + packed_channels[0] = nir_pack_unorm_2x16(b, nir_channels(b, value, 0x3)); + packed_channels[1] = nir_pack_unorm_2x16(b, nir_channels(b, value, 0x3 << 2)); + src_as_vec = nir_vec(b, packed_channels, 2); + } + } else if (from_desc->is_snorm) { + if (from_desc->channel[0].size == 8) + src_as_vec = nir_pack_snorm_4x8(b, value); + else { + nir_ssa_def *packed_channels[2]; + packed_channels[0] = nir_pack_snorm_2x16(b, nir_channels(b, value, 0x3)); + packed_channels[1] = nir_pack_snorm_2x16(b, nir_channels(b, value, 0x3 << 2)); + src_as_vec = nir_vec(b, packed_channels, 2); + } + } else if (util_format_is_float(from_desc->format)) { + src_as_vec = nir_f2fN(b, value, from_desc->channel[0].size); + } else if (util_format_is_pure_sint(from_desc->format)) { + src_as_vec = nir_i2iN(b, value, from_desc->channel[0].size); + } else { + src_as_vec = nir_u2uN(b, value, from_desc->channel[0].size); + } + + /* Now that we have the tightly packed bits, we can use nir_extract_bits to get it into a + * vector of differently-sized components. For producing packed formats, get a 32-bit + * value and manually extract the bits. For unorm/snorm, get one or two 32-bit values, + * and extract it using helpers. Otherwise, get a format-sized dest vector and use a + * cast to expand it back to 32-bit. + * + * Pay extra attention for changing semantics for alpha as 1. + */ + if (to_desc->format == PIPE_FORMAT_R10G10B10A2_UINT || + to_desc->format == PIPE_FORMAT_R10G10B10A2_UNORM) { + nir_ssa_def *u32 = nir_extract_bits(b, &src_as_vec, 1, 0, 1, 32); + nir_ssa_def *channels[4] = { + nir_iand(b, u32, nir_imm_int(b, (1 << 10) - 1)), + nir_iand(b, nir_ushr(b, u32, nir_imm_int(b, 10)), nir_imm_int(b, (1 << 10) - 1)), + nir_iand(b, nir_ushr(b, u32, nir_imm_int(b, 20)), nir_imm_int(b, (1 << 10) - 1)), + nir_ushr(b, u32, nir_imm_int(b, 30)) + }; + nir_ssa_def *vec = nir_vec(b, channels, 4); + if (to_desc->format == PIPE_FORMAT_R10G10B10A2_UNORM) + vec = nir_format_unorm_to_float(b, vec, rgba1010102_bits); + return vec; + } else if (to_desc->format == PIPE_FORMAT_R11G11B10_FLOAT) { + nir_ssa_def *u32 = nir_extract_bits(b, &src_as_vec, 1, 0, 1, 32); + nir_ssa_def *vec3 = nir_format_unpack_11f11f10f(b, u32); + return nir_vec4(b, nir_channel(b, vec3, 0), + nir_channel(b, vec3, 1), + nir_channel(b, vec3, 2), + nir_imm_float(b, 1.0f)); + } else if (to_desc->is_unorm || to_desc->is_snorm) { + nir_ssa_def *dest_packed = nir_extract_bits(b, &src_as_vec, 1, 0, + DIV_ROUND_UP(to_desc->nr_channels * to_desc->channel[0].size, 32), 32); + if (to_desc->is_unorm) { + if (to_desc->channel[0].size == 8) { + nir_ssa_def *unpacked = nir_unpack_unorm_4x8(b, nir_channel(b, dest_packed, 0)); + if (to_desc->nr_channels < 4) + unpacked = nir_vector_insert_imm(b, unpacked, nir_imm_float(b, 1.0f), 3); + return unpacked; + } + nir_ssa_def *vec2s[2] = { + nir_unpack_unorm_2x16(b, nir_channel(b, dest_packed, 0)), + to_desc->nr_channels > 2 ? + nir_unpack_unorm_2x16(b, nir_channel(b, dest_packed, 1)) : + nir_vec2(b, nir_imm_float(b, 0.0f), nir_imm_float(b, 1.0f)) + }; + if (to_desc->nr_channels == 1) + vec2s[0] = nir_vector_insert_imm(b, vec2s[0], nir_imm_float(b, 0.0f), 1); + return nir_vec4(b, nir_channel(b, vec2s[0], 0), + nir_channel(b, vec2s[0], 1), + nir_channel(b, vec2s[1], 0), + nir_channel(b, vec2s[1], 1)); + } else { + if (to_desc->channel[0].size == 8) { + nir_ssa_def *unpacked = nir_unpack_snorm_4x8(b, nir_channel(b, dest_packed, 0)); + if (to_desc->nr_channels < 4) + unpacked = nir_vector_insert_imm(b, unpacked, nir_imm_float(b, 1.0f), 3); + return unpacked; + } + nir_ssa_def *vec2s[2] = { + nir_unpack_snorm_2x16(b, nir_channel(b, dest_packed, 0)), + to_desc->nr_channels > 2 ? + nir_unpack_snorm_2x16(b, nir_channel(b, dest_packed, 1)) : + nir_vec2(b, nir_imm_float(b, 0.0f), nir_imm_float(b, 1.0f)) + }; + if (to_desc->nr_channels == 1) + vec2s[0] = nir_vector_insert_imm(b, vec2s[0], nir_imm_float(b, 0.0f), 1); + return nir_vec4(b, nir_channel(b, vec2s[0], 0), + nir_channel(b, vec2s[0], 1), + nir_channel(b, vec2s[1], 0), + nir_channel(b, vec2s[1], 1)); + } + } else { + nir_ssa_def *dest_packed = nir_extract_bits(b, &src_as_vec, 1, 0, + to_desc->nr_channels, to_desc->channel[0].size); + nir_ssa_def *final_channels[4]; + for (unsigned i = 0; i < 4; ++i) { + if (i >= dest_packed->num_components) + final_channels[i] = util_format_is_float(to_desc->format) ? + nir_imm_floatN_t(b, i == 3 ? 1.0f : 0.0f, to_desc->channel[0].size) : + nir_imm_intN_t(b, i == 3 ? 1 : 0, to_desc->channel[0].size); + else + final_channels[i] = nir_channel(b, dest_packed, i); + } + nir_ssa_def *final_vec = nir_vec(b, final_channels, 4); + if (util_format_is_float(to_desc->format)) + return nir_f2f32(b, final_vec); + else if (util_format_is_pure_sint(to_desc->format)) + return nir_i2i32(b, final_vec); + else + return nir_u2u32(b, final_vec); + } +} + +static bool +lower_image_cast_instr(nir_builder *b, nir_instr *instr, void *_data) +{ + if (instr->type != nir_instr_type_intrinsic) + return false; + + nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); + if (intr->intrinsic != nir_intrinsic_image_deref_load && + intr->intrinsic != nir_intrinsic_image_deref_store) + return false; + + const struct d3d12_image_format_conversion_info *info = _data; + nir_variable *image = nir_deref_instr_get_variable(nir_src_as_deref(intr->src[0])); + assert(image); + + enum pipe_format emulation_format = info[image->data.driver_location].emulated_format; + if (emulation_format == PIPE_FORMAT_NONE) + return false; + + enum pipe_format real_format = info[image->data.driver_location].view_format; + assert(real_format != emulation_format); + + nir_ssa_def *value; + const struct util_format_description *from_desc, *to_desc; + if (intr->intrinsic == nir_intrinsic_image_deref_load) { + b->cursor = nir_after_instr(instr); + value = &intr->dest.ssa; + from_desc = util_format_description(emulation_format); + to_desc = util_format_description(real_format); + } else { + b->cursor = nir_before_instr(instr); + value = intr->src[3].ssa; + from_desc = util_format_description(real_format); + to_desc = util_format_description(emulation_format); + } + + nir_ssa_def *new_value = convert_value(b, value, from_desc, to_desc); + + nir_alu_type alu_type = util_format_is_pure_uint(emulation_format) ? + nir_type_uint : (util_format_is_pure_sint(emulation_format) ? + nir_type_int : nir_type_float); + + if (intr->intrinsic == nir_intrinsic_image_deref_load) { + nir_ssa_def_rewrite_uses_after(value, new_value, new_value->parent_instr); + nir_intrinsic_set_dest_type(intr, alu_type); + } else { + nir_instr_rewrite_src_ssa(instr, &intr->src[3], new_value); + nir_intrinsic_set_src_type(intr, alu_type); + } + nir_intrinsic_set_format(intr, emulation_format); + return true; +} + +/* Given a shader that does image loads/stores expecting to load from the format embedded in the intrinsic, + * if the corresponding entry in formats is not PIPE_FORMAT_NONE, replace the image format and convert + * the data being loaded/stored to/from the app's expected format. + */ +bool +d3d12_lower_image_casts(nir_shader *s, struct d3d12_image_format_conversion_info *info) +{ + bool progress = nir_shader_instructions_pass(s, lower_image_cast_instr, + nir_metadata_block_index | nir_metadata_dominance, info); + + if (progress) { + nir_foreach_image_variable(var, s) { + if (info[var->data.driver_location].emulated_format != PIPE_FORMAT_NONE) { + var->data.image.format = info[var->data.driver_location].emulated_format; + } + } + } + + return progress; +} diff --git a/src/gallium/drivers/d3d12/d3d12_nir_passes.h b/src/gallium/drivers/d3d12/d3d12_nir_passes.h index 4461939013e..54a9d2452fb 100644 --- a/src/gallium/drivers/d3d12/d3d12_nir_passes.h +++ b/src/gallium/drivers/d3d12/d3d12_nir_passes.h @@ -31,6 +31,7 @@ extern "C" { #endif struct d3d12_shader; +struct d3d12_image_format_conversion_info; bool d3d12_lower_point_sprite(nir_shader *shader, @@ -82,6 +83,9 @@ d3d12_lower_primitive_id(nir_shader *shader); void d3d12_lower_triangle_strip(nir_shader *shader); +bool +d3d12_lower_image_casts(nir_shader *s, struct d3d12_image_format_conversion_info *info); + #ifdef __cplusplus } #endif diff --git a/src/gallium/drivers/d3d12/d3d12_resource.cpp b/src/gallium/drivers/d3d12/d3d12_resource.cpp index 161844ce1b8..9c646e451ab 100644 --- a/src/gallium/drivers/d3d12/d3d12_resource.cpp +++ b/src/gallium/drivers/d3d12/d3d12_resource.cpp @@ -230,6 +230,7 @@ init_texture(struct d3d12_screen *screen, (support.Support2 & (D3D12_FORMAT_SUPPORT2_UAV_TYPED_LOAD | D3D12_FORMAT_SUPPORT2_UAV_TYPED_STORE)) == (D3D12_FORMAT_SUPPORT2_UAV_TYPED_LOAD | D3D12_FORMAT_SUPPORT2_UAV_TYPED_STORE)) { desc.Flags |= D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS; + desc.Format = d3d12_get_typeless_format(templ->format); } } diff --git a/src/gallium/drivers/d3d12/meson.build b/src/gallium/drivers/d3d12/meson.build index 6132654863b..2dc44a082cd 100644 --- a/src/gallium/drivers/d3d12/meson.build +++ b/src/gallium/drivers/d3d12/meson.build @@ -31,6 +31,7 @@ files_libd3d12 = files( 'd3d12_fence.cpp', 'd3d12_format.c', 'd3d12_gs_variant.cpp', + 'd3d12_lower_image_casts.c', 'd3d12_lower_int_cubemap_to_array.c', 'd3d12_lower_point_sprite.c', 'd3d12_nir_lower_texcmp.c',