From 5d9ef0efb57781f132bd8b47a43bde5e0d13baf8 Mon Sep 17 00:00:00 2001 From: Konstantin Seurer Date: Thu, 7 Apr 2022 15:39:52 +0200 Subject: [PATCH] radv: Add the fuchsia radix sort Signed-off-by: Konstantin Seurer Reviewed-by: Bas Nieuwenhuizen Part-of: --- meson.build | 2 +- src/amd/vulkan/meson.build | 5 +- src/amd/vulkan/radix_sort/LICENSE | 24 + src/amd/vulkan/radix_sort/common/macros.h | 112 ++ src/amd/vulkan/radix_sort/common/util.c | 90 + src/amd/vulkan/radix_sort/common/util.h | 59 + src/amd/vulkan/radix_sort/common/vk/assert.c | 108 ++ src/amd/vulkan/radix_sort/common/vk/assert.h | 52 + src/amd/vulkan/radix_sort/common/vk/barrier.c | 305 +++ src/amd/vulkan/radix_sort/common/vk/barrier.h | 72 + src/amd/vulkan/radix_sort/meson.build | 40 + src/amd/vulkan/radix_sort/radix_sort_vk.c | 1240 ++++++++++++ src/amd/vulkan/radix_sort/radix_sort_vk.h | 384 ++++ .../vulkan/radix_sort/radix_sort_vk_devaddr.h | 104 + src/amd/vulkan/radix_sort/radix_sort_vk_ext.h | 77 + src/amd/vulkan/radix_sort/radv_radix_sort.c | 193 ++ src/amd/vulkan/radix_sort/radv_radix_sort.h | 32 + src/amd/vulkan/radix_sort/shaders/bufref.h | 151 ++ src/amd/vulkan/radix_sort/shaders/fill.comp | 143 ++ .../vulkan/radix_sort/shaders/histogram.comp | 449 +++++ src/amd/vulkan/radix_sort/shaders/init.comp | 168 ++ src/amd/vulkan/radix_sort/shaders/meson.build | 51 + src/amd/vulkan/radix_sort/shaders/prefix.comp | 194 ++ src/amd/vulkan/radix_sort/shaders/prefix.h | 353 ++++ .../vulkan/radix_sort/shaders/prefix_limits.h | 48 + src/amd/vulkan/radix_sort/shaders/push.h | 263 +++ .../vulkan/radix_sort/shaders/scatter.glsl | 1706 +++++++++++++++++ .../radix_sort/shaders/scatter_0_even.comp | 36 + .../radix_sort/shaders/scatter_0_odd.comp | 36 + .../radix_sort/shaders/scatter_1_even.comp | 36 + .../radix_sort/shaders/scatter_1_odd.comp | 36 + src/amd/vulkan/radix_sort/target.h | 57 + .../vulkan/radix_sort/targets/u64/config.h | 34 + 33 files changed, 6658 insertions(+), 2 deletions(-) create mode 100644 src/amd/vulkan/radix_sort/LICENSE create mode 100644 src/amd/vulkan/radix_sort/common/macros.h create mode 100644 src/amd/vulkan/radix_sort/common/util.c create mode 100644 src/amd/vulkan/radix_sort/common/util.h create mode 100644 src/amd/vulkan/radix_sort/common/vk/assert.c create mode 100644 src/amd/vulkan/radix_sort/common/vk/assert.h create mode 100644 src/amd/vulkan/radix_sort/common/vk/barrier.c create mode 100644 src/amd/vulkan/radix_sort/common/vk/barrier.h create mode 100644 src/amd/vulkan/radix_sort/meson.build create mode 100644 src/amd/vulkan/radix_sort/radix_sort_vk.c create mode 100644 src/amd/vulkan/radix_sort/radix_sort_vk.h create mode 100644 src/amd/vulkan/radix_sort/radix_sort_vk_devaddr.h create mode 100644 src/amd/vulkan/radix_sort/radix_sort_vk_ext.h create mode 100644 src/amd/vulkan/radix_sort/radv_radix_sort.c create mode 100644 src/amd/vulkan/radix_sort/radv_radix_sort.h create mode 100644 src/amd/vulkan/radix_sort/shaders/bufref.h create mode 100644 src/amd/vulkan/radix_sort/shaders/fill.comp create mode 100644 src/amd/vulkan/radix_sort/shaders/histogram.comp create mode 100644 src/amd/vulkan/radix_sort/shaders/init.comp create mode 100644 src/amd/vulkan/radix_sort/shaders/meson.build create mode 100644 src/amd/vulkan/radix_sort/shaders/prefix.comp create mode 100644 src/amd/vulkan/radix_sort/shaders/prefix.h create mode 100644 src/amd/vulkan/radix_sort/shaders/prefix_limits.h create mode 100644 src/amd/vulkan/radix_sort/shaders/push.h create mode 100644 src/amd/vulkan/radix_sort/shaders/scatter.glsl create 
mode 100644 src/amd/vulkan/radix_sort/shaders/scatter_0_even.comp create mode 100644 src/amd/vulkan/radix_sort/shaders/scatter_0_odd.comp create mode 100644 src/amd/vulkan/radix_sort/shaders/scatter_1_even.comp create mode 100644 src/amd/vulkan/radix_sort/shaders/scatter_1_odd.comp create mode 100644 src/amd/vulkan/radix_sort/target.h create mode 100644 src/amd/vulkan/radix_sort/targets/u64/config.h diff --git a/meson.build b/meson.build index 8f71422803b..1c7c6beeda8 100644 --- a/meson.build +++ b/meson.build @@ -680,7 +680,7 @@ if with_gallium_d3d12 or with_microsoft_clc or with_microsoft_vk endif endif -if with_vulkan_overlay_layer or with_aco_tests +if with_vulkan_overlay_layer or with_aco_tests or with_amd_vk prog_glslang = find_program('glslangValidator') endif diff --git a/src/amd/vulkan/meson.build b/src/amd/vulkan/meson.build index b6ecd999e71..2ad24c29520 100644 --- a/src/amd/vulkan/meson.build +++ b/src/amd/vulkan/meson.build @@ -120,6 +120,9 @@ if with_llvm ) endif +subdir('radix_sort') +libradv_files += radix_sort_files + radv_deps = [] radv_flags = cc.get_supported_arguments(['-Wimplicit-fallthrough', '-Wshadow']) @@ -151,7 +154,7 @@ endif libvulkan_radeon = shared_library( 'vulkan_radeon', - [libradv_files, radv_entrypoints, sha1_h], + [libradv_files, radv_entrypoints, sha1_h, radix_sort_spv], vs_module_defs : vulkan_api_def, include_directories : [ inc_include, inc_src, inc_mapi, inc_mesa, inc_gallium, inc_gallium_aux, inc_amd, inc_amd_common, inc_amd_common_llvm, inc_compiler, inc_util, diff --git a/src/amd/vulkan/radix_sort/LICENSE b/src/amd/vulkan/radix_sort/LICENSE new file mode 100644 index 00000000000..7ed244f42dc --- /dev/null +++ b/src/amd/vulkan/radix_sort/LICENSE @@ -0,0 +1,24 @@ +Copyright 2019 The Fuchsia Authors. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/src/amd/vulkan/radix_sort/common/macros.h b/src/amd/vulkan/radix_sort/common/macros.h new file mode 100644 index 00000000000..475d4e426c1 --- /dev/null +++ b/src/amd/vulkan/radix_sort/common/macros.h @@ -0,0 +1,112 @@ +// Copyright 2019 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +#ifndef SRC_GRAPHICS_LIB_COMPUTE_COMMON_MACROS_H_ +#define SRC_GRAPHICS_LIB_COMPUTE_COMMON_MACROS_H_ + +// +// +// + +#include +#include +#include + +// +// clang-format off +// + +#define ARRAY_LENGTH_MACRO(x_) (sizeof(x_)/sizeof(x_[0])) +#define OFFSETOF_MACRO(t_,m_) offsetof(t_,m_) +#define MEMBER_SIZE_MACRO(t_,m_) sizeof(((t_*)0)->m_) + +// +// FIXME(allanmac): +// +// Consider providing typed min/max() functions: +// +// [min|max]_(a,b) { ; } +// +// But note we still need preprocessor-time min/max(). +// + +#define MAX_MACRO(t_,a_,b_) (((a_) > (b_)) ? (a_) : (b_)) +#define MIN_MACRO(t_,a_,b_) (((a_) < (b_)) ? (a_) : (b_)) + +// +// +// + +#define BITS_TO_MASK_MACRO(n_) (((uint32_t)1<<(n_))-1) +#define BITS_TO_MASK_64_MACRO(n_) (((uint64_t)1<<(n_))-1) + +#define BITS_TO_MASK_AT_MACRO(n_,b_) (BITS_TO_MASK_MACRO(n_) <<(b_)) +#define BITS_TO_MASK_AT_64_MACRO(n_,b_) (BITS_TO_MASK_64_MACRO(n_)<<(b_)) + +// +// +// + +#define STRINGIFY_MACRO_2(a_) #a_ +#define STRINGIFY_MACRO(a_) STRINGIFY_MACRO_2(a_) + +// +// +// + +#define CONCAT_MACRO_2(a_,b_) a_ ## b_ +#define CONCAT_MACRO(a_,b_) CONCAT_MACRO_2(a_,b_) + +// +// Round up/down +// + +#define ROUND_DOWN_MACRO(v_,q_) (((v_) / (q_)) * (q_)) +#define ROUND_UP_MACRO(v_,q_) ((((v_) + (q_) - 1) / (q_)) * (q_)) + +// +// Round up/down when q is a power-of-two. +// + +#define ROUND_DOWN_POW2_MACRO(v_,q_) ((v_) & ~((q_) - 1)) +#define ROUND_UP_POW2_MACRO(v_,q_) ROUND_DOWN_POW2_MACRO((v_) + (q_) - 1, q_) + +// +// +// + +#if defined (_MSC_VER) && !defined (__clang__) +#define STATIC_ASSERT_MACRO(c_,m_) static_assert(c_,m_) +#else +#define STATIC_ASSERT_MACRO(c_,m_) _Static_assert(c_,m_) +#endif + +#define STATIC_ASSERT_MACRO_1(c_) STATIC_ASSERT_MACRO(c_,#c_) + +// +// +// + +#if defined (_MSC_VER) && !defined (__clang__) +#define POPCOUNT_MACRO(...) __popcnt(__VA_ARGS__) +#else +#define POPCOUNT_MACRO(...) __builtin_popcount(__VA_ARGS__) +#endif + +// +// +// + +#if defined (_MSC_VER) && !defined (__clang__) +#define ALIGN_MACRO(bytes_) __declspec(align(bytes_)) // only accepts integer as arg +#else +#include +#define ALIGN_MACRO(bytes_) alignas(bytes_) +#endif + +// +// clang-format on +// + +#endif // SRC_GRAPHICS_LIB_COMPUTE_COMMON_MACROS_H_ diff --git a/src/amd/vulkan/radix_sort/common/util.c b/src/amd/vulkan/radix_sort/common/util.c new file mode 100644 index 00000000000..2ed2425d760 --- /dev/null +++ b/src/amd/vulkan/radix_sort/common/util.c @@ -0,0 +1,90 @@ +// Copyright 2019 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +#include "util.h" + +#include + +// +// +// + +#if defined(_MSC_VER) && !defined(__clang__) + +#include + +#endif + +// +// +// + +bool +is_pow2_u32(uint32_t n) +{ + return n && !(n & (n - 1)); +} + +// +// +// + +uint32_t +pow2_ru_u32(uint32_t n) +{ + assert(n <= 0x80000000U); + + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + n++; + + return n; +} + +// +// +// + +uint32_t +pow2_rd_u32(uint32_t n) +{ + assert(n > 0); + + return 1u << msb_idx_u32(n); +} + +// +// ASSUMES NON-ZERO +// + +uint32_t +msb_idx_u32(uint32_t n) +{ + assert(n > 0); +#if defined(_MSC_VER) && !defined(__clang__) + + uint32_t index; + + _BitScanReverse((unsigned long *)&index, n); + + return index; + +#elif defined(__GNUC__) + + return __builtin_clz(n) ^ 31; + +#else + +#error "No msb_index()" + +#endif +} + +// +// +// diff --git a/src/amd/vulkan/radix_sort/common/util.h b/src/amd/vulkan/radix_sort/common/util.h new file mode 100644 index 00000000000..815231eb7aa --- /dev/null +++ b/src/amd/vulkan/radix_sort/common/util.h @@ -0,0 +1,59 @@ +// Copyright 2019 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SRC_GRAPHICS_LIB_COMPUTE_COMMON_UTIL_H_ +#define SRC_GRAPHICS_LIB_COMPUTE_COMMON_UTIL_H_ + +// +// +// + +#include +#include + +// +// +// + +#ifdef __cplusplus +extern "C" { +#endif + +// +// +// + +// Return true iff |n| is a power of 2. +bool +is_pow2_u32(uint32_t n); + +// Return |n| rounded-up to the nearest power of 2. +// If |n| is zero then return 0. +// REQUIRES: |n <= 0x80000000|. +uint32_t +pow2_ru_u32(uint32_t n); + +// Return |n| rounded-down to the nearest power of 2. +// REQUIRES: |n > 0|. +uint32_t +pow2_rd_u32(uint32_t n); + +// Return the most-significant bit position for |n|. +// REQUIRES: |n > 0|. +uint32_t +msb_idx_u32(uint32_t n); // 0-based bit position + +// +// +// + +#ifdef __cplusplus +} +#endif + +// +// +// + +#endif // SRC_GRAPHICS_LIB_COMPUTE_COMMON_UTIL_H_ diff --git a/src/amd/vulkan/radix_sort/common/vk/assert.c b/src/amd/vulkan/radix_sort/common/vk/assert.c new file mode 100644 index 00000000000..d5d5d07b454 --- /dev/null +++ b/src/amd/vulkan/radix_sort/common/vk/assert.c @@ -0,0 +1,108 @@ +// Copyright 2019 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +// +// +// + +#include +#include + +// +// +// + +#include "assert.h" + +// +// +// + +#define VK_RESULT_TO_STRING(result) \ + case result: \ + return #result + +// +// FIXME -- results and errors +// + +char const * +vk_get_result_string(VkResult const result) +{ + switch (result) + { + // + // Results + // + VK_RESULT_TO_STRING(VK_SUCCESS); + VK_RESULT_TO_STRING(VK_NOT_READY); + VK_RESULT_TO_STRING(VK_TIMEOUT); + VK_RESULT_TO_STRING(VK_EVENT_SET); + VK_RESULT_TO_STRING(VK_EVENT_RESET); + VK_RESULT_TO_STRING(VK_INCOMPLETE); + // + // Errors + // + VK_RESULT_TO_STRING(VK_ERROR_OUT_OF_HOST_MEMORY); + VK_RESULT_TO_STRING(VK_ERROR_OUT_OF_DEVICE_MEMORY); + VK_RESULT_TO_STRING(VK_ERROR_INITIALIZATION_FAILED); + VK_RESULT_TO_STRING(VK_ERROR_DEVICE_LOST); + VK_RESULT_TO_STRING(VK_ERROR_MEMORY_MAP_FAILED); + VK_RESULT_TO_STRING(VK_ERROR_LAYER_NOT_PRESENT); + VK_RESULT_TO_STRING(VK_ERROR_EXTENSION_NOT_PRESENT); + VK_RESULT_TO_STRING(VK_ERROR_FEATURE_NOT_PRESENT); + VK_RESULT_TO_STRING(VK_ERROR_INCOMPATIBLE_DRIVER); + VK_RESULT_TO_STRING(VK_ERROR_TOO_MANY_OBJECTS); + VK_RESULT_TO_STRING(VK_ERROR_FORMAT_NOT_SUPPORTED); + VK_RESULT_TO_STRING(VK_ERROR_FRAGMENTED_POOL); + VK_RESULT_TO_STRING(VK_ERROR_OUT_OF_POOL_MEMORY); + VK_RESULT_TO_STRING(VK_ERROR_INVALID_EXTERNAL_HANDLE); + VK_RESULT_TO_STRING(VK_ERROR_SURFACE_LOST_KHR); + VK_RESULT_TO_STRING(VK_ERROR_NATIVE_WINDOW_IN_USE_KHR); + VK_RESULT_TO_STRING(VK_SUBOPTIMAL_KHR); + VK_RESULT_TO_STRING(VK_ERROR_OUT_OF_DATE_KHR); + VK_RESULT_TO_STRING(VK_ERROR_INCOMPATIBLE_DISPLAY_KHR); + VK_RESULT_TO_STRING(VK_ERROR_VALIDATION_FAILED_EXT); + VK_RESULT_TO_STRING(VK_ERROR_INVALID_SHADER_NV); + VK_RESULT_TO_STRING(VK_ERROR_FRAGMENTATION_EXT); + VK_RESULT_TO_STRING(VK_ERROR_NOT_PERMITTED_EXT); + + // + // Extensions: vk_xyz + // + default: + return "UNKNOWN VULKAN RESULT"; + } +} + +// +// +// + +VkResult +vk_assert(VkResult const result, char const * const file, int const line, bool const is_abort) +{ + if (result != VK_SUCCESS) + { + char const * const vk_result_str = vk_get_result_string(result); + + fprintf(stderr, + "\"%s\", line %d: vk_assert( %d ) = \"%s\"\n", + file, + line, + result, + vk_result_str); + + if (is_abort) + { + abort(); + } + } + + return result; +} + +// +// +// diff --git a/src/amd/vulkan/radix_sort/common/vk/assert.h b/src/amd/vulkan/radix_sort/common/vk/assert.h new file mode 100644 index 00000000000..9d3fe43ed15 --- /dev/null +++ b/src/amd/vulkan/radix_sort/common/vk/assert.h @@ -0,0 +1,52 @@ +// Copyright 2019 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SRC_GRAPHICS_LIB_COMPUTE_COMMON_VK_ASSERT_H_ +#define SRC_GRAPHICS_LIB_COMPUTE_COMMON_VK_ASSERT_H_ + +// +// +// + +#include +#include + +// +// +// + +#ifdef __cplusplus +extern "C" { +#endif + +// +// +// + +char const * +vk_get_result_string(VkResult const result); + +VkResult +vk_assert(VkResult const result, char const * const file, int const line, bool const is_abort); + +// +// clang-format off +// + +#define vk(...) 
vk_assert((vk##__VA_ARGS__), __FILE__, __LINE__, true); +#define vk_ok(err) vk_assert(err, __FILE__, __LINE__, true); + +// +// clang-format on +// + +#ifdef __cplusplus +} +#endif + +// +// +// + +#endif // SRC_GRAPHICS_LIB_COMPUTE_COMMON_VK_ASSERT_H_ diff --git a/src/amd/vulkan/radix_sort/common/vk/barrier.c b/src/amd/vulkan/radix_sort/common/vk/barrier.c new file mode 100644 index 00000000000..58134dbd11a --- /dev/null +++ b/src/amd/vulkan/radix_sort/common/vk/barrier.c @@ -0,0 +1,305 @@ +// Copyright 2019 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// +// +// + +#include "barrier.h" + +// +// +// + +void +vk_barrier_compute_w_to_compute_r(VkCommandBuffer cb) +{ + static VkMemoryBarrier const mb = { + + .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER, + .pNext = NULL, + .srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT, + .dstAccessMask = VK_ACCESS_SHADER_READ_BIT + }; + + vkCmdPipelineBarrier(cb, + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + 0, + 1, + &mb, + 0, + NULL, + 0, + NULL); +} + +// +// +// + +void +vk_barrier_compute_w_to_transfer_r(VkCommandBuffer cb) +{ + static VkMemoryBarrier const mb = { + + .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER, + .pNext = NULL, + .srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT, + .dstAccessMask = VK_ACCESS_TRANSFER_READ_BIT + }; + + vkCmdPipelineBarrier(cb, + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT, + 0, + 1, + &mb, + 0, + NULL, + 0, + NULL); +} + +// +// +// + +void +vk_barrier_transfer_w_to_compute_r(VkCommandBuffer cb) +{ + static VkMemoryBarrier const mb = { + + .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER, + .pNext = NULL, + .srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT, + .dstAccessMask = VK_ACCESS_SHADER_READ_BIT + }; + + vkCmdPipelineBarrier(cb, + VK_PIPELINE_STAGE_TRANSFER_BIT, + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + 0, + 1, + &mb, + 0, + NULL, + 0, + NULL); +} + +// +// +// + +void +vk_barrier_transfer_w_to_compute_w(VkCommandBuffer cb) +{ + static VkMemoryBarrier const mb = { + + .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER, + .pNext = NULL, + .srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT, + .dstAccessMask = VK_ACCESS_SHADER_WRITE_BIT + }; + + vkCmdPipelineBarrier(cb, + VK_PIPELINE_STAGE_TRANSFER_BIT, + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + 0, + 1, + &mb, + 0, + NULL, + 0, + NULL); +} + +// +// +// + +void +vk_barrier_compute_w_to_indirect_compute_r(VkCommandBuffer cb) +{ + static VkMemoryBarrier const mb = { + + .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER, + .pNext = NULL, + .srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT, + .dstAccessMask = VK_ACCESS_INDIRECT_COMMAND_READ_BIT | // + VK_ACCESS_SHADER_READ_BIT + }; + + vkCmdPipelineBarrier(cb, + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_PIPELINE_STAGE_DRAW_INDIRECT_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + 0, + 1, + &mb, + 0, + NULL, + 0, + NULL); +} + +// +// +// + +void +vk_barrier_transfer_w_compute_w_to_transfer_r(VkCommandBuffer cb) +{ + static VkMemoryBarrier const mb = { + + .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER, + .pNext = NULL, + .srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT | // + VK_ACCESS_SHADER_WRITE_BIT, + .dstAccessMask = VK_ACCESS_TRANSFER_READ_BIT + }; + + vkCmdPipelineBarrier(cb, + VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT, + 0, + 1, + &mb, + 0, + NULL, + 0, + NULL); +} + +// +// +// + +void +vk_barrier_compute_w_to_host_r(VkCommandBuffer cb) +{ + 
static VkMemoryBarrier const mb = { + + .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER, + .pNext = NULL, + .srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT, + .dstAccessMask = VK_ACCESS_HOST_READ_BIT + }; + + vkCmdPipelineBarrier(cb, + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_PIPELINE_STAGE_HOST_BIT, + 0, + 1, + &mb, + 0, + NULL, + 0, + NULL); +} + +// +// +// + +void +vk_barrier_transfer_w_to_host_r(VkCommandBuffer cb) +{ + static VkMemoryBarrier const mb = { + + .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER, + .pNext = NULL, + .srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT, + .dstAccessMask = VK_ACCESS_HOST_READ_BIT + }; + + vkCmdPipelineBarrier(cb, + VK_PIPELINE_STAGE_TRANSFER_BIT, + VK_PIPELINE_STAGE_HOST_BIT, + 0, + 1, + &mb, + 0, + NULL, + 0, + NULL); +} + +// +// +// + +void +vk_memory_barrier(VkCommandBuffer cb, + VkPipelineStageFlags src_stage, + VkAccessFlags src_mask, + VkPipelineStageFlags dst_stage, + VkAccessFlags dst_mask) +{ + VkMemoryBarrier const mb = { .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER, + .pNext = NULL, + .srcAccessMask = src_mask, + .dstAccessMask = dst_mask }; + + vkCmdPipelineBarrier(cb, src_stage, dst_stage, 0, 1, &mb, 0, NULL, 0, NULL); +} + +// +// +// + +void +vk_barrier_debug(VkCommandBuffer cb) +{ + static VkMemoryBarrier const mb = { + + .sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER, + .pNext = NULL, + .srcAccessMask = VK_ACCESS_INDIRECT_COMMAND_READ_BIT | // + VK_ACCESS_INDEX_READ_BIT | // + VK_ACCESS_VERTEX_ATTRIBUTE_READ_BIT | // + VK_ACCESS_UNIFORM_READ_BIT | // + VK_ACCESS_INPUT_ATTACHMENT_READ_BIT | // + VK_ACCESS_SHADER_READ_BIT | // + VK_ACCESS_SHADER_WRITE_BIT | // + VK_ACCESS_COLOR_ATTACHMENT_READ_BIT | // + VK_ACCESS_COLOR_ATTACHMENT_WRITE_BIT | // + VK_ACCESS_DEPTH_STENCIL_ATTACHMENT_READ_BIT | // + VK_ACCESS_DEPTH_STENCIL_ATTACHMENT_WRITE_BIT | // + VK_ACCESS_TRANSFER_READ_BIT | // + VK_ACCESS_TRANSFER_WRITE_BIT | // + VK_ACCESS_HOST_READ_BIT | // + VK_ACCESS_HOST_WRITE_BIT, + .dstAccessMask = VK_ACCESS_INDIRECT_COMMAND_READ_BIT | // + VK_ACCESS_INDEX_READ_BIT | // + VK_ACCESS_VERTEX_ATTRIBUTE_READ_BIT | // + VK_ACCESS_UNIFORM_READ_BIT | // + VK_ACCESS_INPUT_ATTACHMENT_READ_BIT | // + VK_ACCESS_SHADER_READ_BIT | // + VK_ACCESS_SHADER_WRITE_BIT | // + VK_ACCESS_COLOR_ATTACHMENT_READ_BIT | // + VK_ACCESS_COLOR_ATTACHMENT_WRITE_BIT | // + VK_ACCESS_DEPTH_STENCIL_ATTACHMENT_READ_BIT | // + VK_ACCESS_DEPTH_STENCIL_ATTACHMENT_WRITE_BIT | // + VK_ACCESS_TRANSFER_READ_BIT | // + VK_ACCESS_TRANSFER_WRITE_BIT | // + VK_ACCESS_HOST_READ_BIT | // + VK_ACCESS_HOST_WRITE_BIT + }; + + vkCmdPipelineBarrier(cb, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, + 0, + 1, + &mb, + 0, + NULL, + 0, + NULL); +} + +// +// +// diff --git a/src/amd/vulkan/radix_sort/common/vk/barrier.h b/src/amd/vulkan/radix_sort/common/vk/barrier.h new file mode 100644 index 00000000000..cda1852dbd1 --- /dev/null +++ b/src/amd/vulkan/radix_sort/common/vk/barrier.h @@ -0,0 +1,72 @@ +// Copyright 2019 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +#ifndef SRC_GRAPHICS_LIB_COMPUTE_COMMON_VK_BARRIER_H_ +#define SRC_GRAPHICS_LIB_COMPUTE_COMMON_VK_BARRIER_H_ + +// +// +// + +#include + +// +// +// + +#ifdef __cplusplus +extern "C" { +#endif + +// +// +// + +void +vk_barrier_compute_w_to_compute_r(VkCommandBuffer cb); + +void +vk_barrier_compute_w_to_transfer_r(VkCommandBuffer cb); + +void +vk_barrier_transfer_w_to_compute_r(VkCommandBuffer cb); + +void +vk_barrier_transfer_w_to_compute_w(VkCommandBuffer cb); + +void +vk_barrier_compute_w_to_indirect_compute_r(VkCommandBuffer cb); + +void +vk_barrier_transfer_w_compute_w_to_transfer_r(VkCommandBuffer cb); + +void +vk_barrier_compute_w_to_host_r(VkCommandBuffer cb); + +void +vk_barrier_transfer_w_to_host_r(VkCommandBuffer cb); + +void +vk_memory_barrier(VkCommandBuffer cb, + VkPipelineStageFlags src_stage, + VkAccessFlags src_mask, + VkPipelineStageFlags dst_stage, + VkAccessFlags dst_mask); + +void +vk_barrier_debug(VkCommandBuffer cb); + +// +// +// + +#ifdef __cplusplus +} +#endif + +// +// +// + +#endif // SRC_GRAPHICS_LIB_COMPUTE_COMMON_VK_BARRIER_H_ diff --git a/src/amd/vulkan/radix_sort/meson.build b/src/amd/vulkan/radix_sort/meson.build new file mode 100644 index 00000000000..46c83847090 --- /dev/null +++ b/src/amd/vulkan/radix_sort/meson.build @@ -0,0 +1,40 @@ +# Copyright © 2022 Konstantin Seurer + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +subdir('shaders') + +radix_sort_files = files( + 'common/vk/assert.c', + 'common/vk/assert.h', + 'common/vk/barrier.c', + 'common/vk/barrier.h', + 'common/macros.h', + 'common/util.c', + 'common/util.h', + 'shaders/push.h', + 'targets/u64/config.h', + 'radix_sort_vk_devaddr.h', + 'radix_sort_vk_ext.h', + 'radix_sort_vk.c', + 'radix_sort_vk.h', + 'radv_radix_sort.c', + 'radv_radix_sort.h', + 'target.h' +) diff --git a/src/amd/vulkan/radix_sort/radix_sort_vk.c b/src/amd/vulkan/radix_sort/radix_sort_vk.c new file mode 100644 index 00000000000..e8be05979d5 --- /dev/null +++ b/src/amd/vulkan/radix_sort/radix_sort_vk.c @@ -0,0 +1,1240 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +#include +#include +#include + +#include "common/macros.h" +#include "common/util.h" +#include "common/vk/assert.h" +#include "common/vk/barrier.h" +#include "radix_sort_vk_devaddr.h" +#include "shaders/push.h" + +// +// +// + +#ifdef RS_VK_ENABLE_DEBUG_UTILS +#include "common/vk/debug_utils.h" +#endif + +// +// +// + +#ifdef RS_VK_ENABLE_EXTENSIONS +#include "radix_sort_vk_ext.h" +#endif + +// +// FIXME(allanmac): memoize some of these calculations +// +void +radix_sort_vk_get_memory_requirements(radix_sort_vk_t const * rs, + uint32_t count, + radix_sort_vk_memory_requirements_t * mr) +{ + // + // Keyval size + // + mr->keyval_size = rs->config.keyval_dwords * sizeof(uint32_t); + + // + // Subgroup and workgroup sizes + // + uint32_t const histo_sg_size = 1 << rs->config.histogram.subgroup_size_log2; + uint32_t const histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2; + uint32_t const prefix_sg_size = 1 << rs->config.prefix.subgroup_size_log2; + uint32_t const scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2; + uint32_t const internal_sg_size = MAX_MACRO(uint32_t, histo_sg_size, prefix_sg_size); + + // + // If for some reason count is zero then initialize appropriately. + // + if (count == 0) + { + mr->keyvals_size = 0; + mr->keyvals_alignment = mr->keyval_size * histo_sg_size; + mr->internal_size = 0; + mr->internal_alignment = internal_sg_size * sizeof(uint32_t); + mr->indirect_size = 0; + mr->indirect_alignment = internal_sg_size * sizeof(uint32_t); + } + else + { + // + // Keyvals + // + + // Round up to the scatter block size. + // + // Then round up to the histogram block size. + // + // Fill the difference between this new count and the original keyval + // count. + // + // How many scatter blocks? + // + uint32_t const scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows; + uint32_t const scatter_blocks = (count + scatter_block_kvs - 1) / scatter_block_kvs; + uint32_t const count_ru_scatter = scatter_blocks * scatter_block_kvs; + + // + // How many histogram blocks? + // + // Note that it's OK to have more max-valued digits counted by the histogram + // than sorted by the scatters because the sort is stable. + // + uint32_t const histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows; + uint32_t const histo_blocks = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs; + uint32_t const count_ru_histo = histo_blocks * histo_block_kvs; + + mr->keyvals_size = mr->keyval_size * count_ru_histo; + mr->keyvals_alignment = mr->keyval_size * histo_sg_size; + + // + // Internal + // + // NOTE: Assumes .histograms are before .partitions. + // + // Last scatter workgroup skips writing to a partition. 
+ // + // One histogram per (keyval byte + partitions) + // + uint32_t const partitions = scatter_blocks - 1; + + mr->internal_size = (mr->keyval_size + partitions) * (RS_RADIX_SIZE * sizeof(uint32_t)); + mr->internal_alignment = internal_sg_size * sizeof(uint32_t); + + // + // Indirect + // + mr->indirect_size = sizeof(struct rs_indirect_info); + mr->indirect_alignment = sizeof(struct u32vec4); + } +} + +// +// +// +#ifdef RS_VK_ENABLE_DEBUG_UTILS + +static void +rs_debug_utils_set(VkDevice device, struct radix_sort_vk * rs) +{ + if (pfn_vkSetDebugUtilsObjectNameEXT != NULL) + { + VkDebugUtilsObjectNameInfoEXT duoni = { + .sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_OBJECT_NAME_INFO_EXT, + .pNext = NULL, + .objectType = VK_OBJECT_TYPE_PIPELINE, + }; + + duoni.objectHandle = (uint64_t)rs->pipelines.named.init; + duoni.pObjectName = "radix_sort_init"; + vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); + + duoni.objectHandle = (uint64_t)rs->pipelines.named.fill; + duoni.pObjectName = "radix_sort_fill"; + vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); + + duoni.objectHandle = (uint64_t)rs->pipelines.named.histogram; + duoni.pObjectName = "radix_sort_histogram"; + vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); + + duoni.objectHandle = (uint64_t)rs->pipelines.named.prefix; + duoni.pObjectName = "radix_sort_prefix"; + vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); + + duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[0].even; + duoni.pObjectName = "radix_sort_scatter_0_even"; + vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); + + duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[0].odd; + duoni.pObjectName = "radix_sort_scatter_0_odd"; + vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); + + if (rs->config.keyval_dwords >= 2) + { + duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[1].even; + duoni.pObjectName = "radix_sort_scatter_1_even"; + vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); + + duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[1].odd; + duoni.pObjectName = "radix_sort_scatter_1_odd"; + vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni)); + } + } +} + +#endif + +// +// How many pipelines are there? +// +static uint32_t +rs_pipeline_count(struct radix_sort_vk const * rs) +{ + return 1 + // init + 1 + // fill + 1 + // histogram + 1 + // prefix + 2 * rs->config.keyval_dwords; // scatters.even/odd[keyval_dwords] +} + +radix_sort_vk_t * +radix_sort_vk_create(VkDevice device, + VkAllocationCallbacks const * ac, + VkPipelineCache pc, + const uint32_t* const* spv, + const uint32_t* spv_sizes, + struct radix_sort_vk_target_config config) +{ + // + // Allocate radix_sort_vk + // + struct radix_sort_vk * const rs = malloc(sizeof(*rs)); + + // + // Save the config for layer + // + rs->config = config; + + // + // How many pipelines? 
+ // + uint32_t const pipeline_count = rs_pipeline_count(rs); + + // + // Prepare to create pipelines + // + VkPushConstantRange const pcr[] = { + { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // + .offset = 0, + .size = sizeof(struct rs_push_init) }, + + { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // + .offset = 0, + .size = sizeof(struct rs_push_fill) }, + + { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // + .offset = 0, + .size = sizeof(struct rs_push_histogram) }, + + { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // + .offset = 0, + .size = sizeof(struct rs_push_prefix) }, + + { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // + .offset = 0, + .size = sizeof(struct rs_push_scatter) }, // scatter_0_even + + { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // + .offset = 0, + .size = sizeof(struct rs_push_scatter) }, // scatter_0_odd + + { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // + .offset = 0, + .size = sizeof(struct rs_push_scatter) }, // scatter_1_even + + { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT, // + .offset = 0, + .size = sizeof(struct rs_push_scatter) }, // scatter_1_odd + }; + + VkPipelineLayoutCreateInfo plci = { + + .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, + .pNext = NULL, + .flags = 0, + .setLayoutCount = 0, + .pSetLayouts = NULL, + .pushConstantRangeCount = 1, + // .pPushConstantRanges = pcr + ii; + }; + + for (uint32_t ii = 0; ii < pipeline_count; ii++) + { + plci.pPushConstantRanges = pcr + ii; + + vk(CreatePipelineLayout(device, &plci, NULL, rs->pipeline_layouts.handles + ii)); + } + + // + // Create compute pipelines + // + VkShaderModuleCreateInfo smci = { + + .sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, + .pNext = NULL, + .flags = 0, + // .codeSize = ar_entries[...].size; + // .pCode = ar_data + ...; + }; + + VkShaderModule sms[ARRAY_LENGTH_MACRO(rs->pipelines.handles)]; + + for (uint32_t ii = 0; ii < pipeline_count; ii++) + { + smci.codeSize = spv_sizes[ii]; + smci.pCode = spv[ii]; + + vk(CreateShaderModule(device, &smci, ac, sms + ii)); + } + + // + // If necessary, set the expected subgroup size + // +#define RS_SUBGROUP_SIZE_CREATE_INFO_SET(size_) \ + { \ + .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT, \ + .pNext = NULL, \ + .requiredSubgroupSize = size_, \ + } + +#undef RS_SUBGROUP_SIZE_CREATE_INFO_NAME +#define RS_SUBGROUP_SIZE_CREATE_INFO_NAME(name_) \ + RS_SUBGROUP_SIZE_CREATE_INFO_SET(1 << config.name_.subgroup_size_log2) + +#define RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(name_) RS_SUBGROUP_SIZE_CREATE_INFO_SET(0) + + VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT const rsscis[] = { + RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(init), // init + RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(fill), // fill + RS_SUBGROUP_SIZE_CREATE_INFO_NAME(histogram), // histogram + RS_SUBGROUP_SIZE_CREATE_INFO_NAME(prefix), // prefix + RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter), // scatter[0].even + RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter), // scatter[0].odd + RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter), // scatter[1].even + RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter), // scatter[1].odd + }; + + // + // Define compute pipeline create infos + // +#undef RS_COMPUTE_PIPELINE_CREATE_INFO_DECL +#define RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(idx_) \ + { .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, \ + .pNext = NULL, \ + .flags = 0, \ + .stage = { .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, \ + .pNext = NULL, \ + .flags = 0, \ + .stage = VK_SHADER_STAGE_COMPUTE_BIT, \ + .module = sms[idx_], \ + .pName = "main", 
\ + .pSpecializationInfo = NULL }, \ + \ + .layout = rs->pipeline_layouts.handles[idx_], \ + .basePipelineHandle = VK_NULL_HANDLE, \ + .basePipelineIndex = 0 } + + VkComputePipelineCreateInfo cpcis[] = { + RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(0), // init + RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(1), // fill + RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(2), // histogram + RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(3), // prefix + RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(4), // scatter[0].even + RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(5), // scatter[0].odd + RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(6), // scatter[1].even + RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(7), // scatter[1].odd + }; + + // + // Which of these compute pipelines require subgroup size control? + // + for (uint32_t ii = 0; ii < pipeline_count; ii++) + { + if (rsscis[ii].requiredSubgroupSize > 1) + { + cpcis[ii].stage.pNext = rsscis + ii; + } + } + + // + // Create the compute pipelines + // + vk(CreateComputePipelines(device, pc, pipeline_count, cpcis, ac, rs->pipelines.handles)); + + // + // Shader modules can be destroyed now + // + for (uint32_t ii = 0; ii < pipeline_count; ii++) + { + vkDestroyShaderModule(device, sms[ii], ac); + } + +#ifdef RS_VK_ENABLE_DEBUG_UTILS + // + // Tag pipelines with names + // + rs_debug_utils_set(device, rs); +#endif + + // + // Calculate "internal" buffer offsets + // + size_t const keyval_bytes = rs->config.keyval_dwords * sizeof(uint32_t); + + // the .range calculation assumes an 8-bit radix + rs->internal.histograms.offset = 0; + rs->internal.histograms.range = keyval_bytes * (RS_RADIX_SIZE * sizeof(uint32_t)); + + // + // NOTE(allanmac): The partitions.offset must be aligned differently if + // RS_RADIX_LOG2 is less than the target's subgroup size log2. At this time, + // no GPU that meets this criteria. 
+ // + rs->internal.partitions.offset = rs->internal.histograms.offset + rs->internal.histograms.range; + + return rs; +} + +// +// +// +void +radix_sort_vk_destroy(struct radix_sort_vk * rs, VkDevice d, VkAllocationCallbacks const * const ac) +{ + uint32_t const pipeline_count = rs_pipeline_count(rs); + + // destroy pipelines + for (uint32_t ii = 0; ii < pipeline_count; ii++) + { + vkDestroyPipeline(d, rs->pipelines.handles[ii], ac); + } + + // destroy pipeline layouts + for (uint32_t ii = 0; ii < pipeline_count; ii++) + { + vkDestroyPipelineLayout(d, rs->pipeline_layouts.handles[ii], ac); + } + + free(rs); +} + +// +// +// +static VkDeviceAddress +rs_get_devaddr(VkDevice device, VkDescriptorBufferInfo const * dbi) +{ + VkBufferDeviceAddressInfo const bdai = { + + .sType = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, + .pNext = NULL, + .buffer = dbi->buffer + }; + + VkDeviceAddress const devaddr = vkGetBufferDeviceAddress(device, &bdai) + dbi->offset; + + return devaddr; +} + +// +// +// +#ifdef RS_VK_ENABLE_EXTENSIONS + +void +rs_ext_cmd_write_timestamp(struct radix_sort_vk_ext_timestamps * ext_timestamps, + VkCommandBuffer cb, + VkPipelineStageFlagBits pipeline_stage) +{ + if ((ext_timestamps != NULL) && + (ext_timestamps->timestamps_set < ext_timestamps->timestamp_count)) + { + vkCmdWriteTimestamp(cb, + pipeline_stage, + ext_timestamps->timestamps, + ext_timestamps->timestamps_set++); + } +} + +#endif + +// +// +// + +#ifdef RS_VK_ENABLE_EXTENSIONS + +struct radix_sort_vk_ext_base +{ + void * ext; + enum radix_sort_vk_ext_type type; +}; + +#endif + +// +// +// +void +radix_sort_vk_sort_devaddr(radix_sort_vk_t const * rs, + radix_sort_vk_sort_devaddr_info_t const * info, + VkDevice device, + VkCommandBuffer cb, + VkDeviceAddress * keyvals_sorted) +{ + // + // Anything to do? + // + if ((info->count <= 1) || (info->key_bits == 0)) + { + *keyvals_sorted = info->keyvals_even.devaddr; + + return; + } + +#ifdef RS_VK_ENABLE_EXTENSIONS + // + // Any extensions? + // + struct radix_sort_vk_ext_timestamps * ext_timestamps = NULL; + + void * ext_next = info->ext; + + while (ext_next != NULL) + { + struct radix_sort_vk_ext_base * const base = ext_next; + + switch (base->type) + { + case RS_VK_EXT_TIMESTAMPS: + ext_timestamps = ext_next; + ext_timestamps->timestamps_set = 0; + break; + } + + ext_next = base->ext; + } +#endif + + //////////////////////////////////////////////////////////////////////// + // + // OVERVIEW + // + // 1. Pad the keyvals in `scatter_even`. + // 2. Zero the `histograms` and `partitions`. + // --- BARRIER --- + // 3. HISTOGRAM is dispatched before PREFIX. + // --- BARRIER --- + // 4. PREFIX is dispatched before the first SCATTER. + // --- BARRIER --- + // 5. One or more SCATTER dispatches. + // + // Note that the `partitions` buffer can be zeroed anytime before the first + // scatter. + // + //////////////////////////////////////////////////////////////////////// + + // + // Label the command buffer + // +#ifdef RS_VK_ENABLE_DEBUG_UTILS + if (pfn_vkCmdBeginDebugUtilsLabelEXT != NULL) + { + VkDebugUtilsLabelEXT const label = { + .sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT, + .pNext = NULL, + .pLabelName = "radix_sort_vk_sort", + }; + + pfn_vkCmdBeginDebugUtilsLabelEXT(cb, &label); + } +#endif + + // + // How many passes? 
+ // + uint32_t const keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t); + uint32_t const keyval_bits = keyval_bytes * 8; + uint32_t const key_bits = MIN_MACRO(uint32_t, info->key_bits, keyval_bits); + uint32_t const passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2; + + *keyvals_sorted = ((passes & 1) != 0) ? info->keyvals_odd : info->keyvals_even.devaddr; + + //////////////////////////////////////////////////////////////////////// + // + // PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS + // + // Pad fractional blocks with max-valued keyvals. + // + // Zero the histograms and partitions buffer. + // + // This assumes the partitions follow the histograms. + // +#ifdef RS_VK_ENABLE_EXTENSIONS + rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT); +#endif + + // + // FIXME(allanmac): Consider precomputing some of these values and hang them + // off `rs`. + // + + // + // How many scatter blocks? + // + uint32_t const scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2; + uint32_t const scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows; + uint32_t const scatter_blocks = (info->count + scatter_block_kvs - 1) / scatter_block_kvs; + uint32_t const count_ru_scatter = scatter_blocks * scatter_block_kvs; + + // + // How many histogram blocks? + // + // Note that it's OK to have more max-valued digits counted by the histogram + // than sorted by the scatters because the sort is stable. + // + uint32_t const histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2; + uint32_t const histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows; + uint32_t const histo_blocks = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs; + uint32_t const count_ru_histo = histo_blocks * histo_block_kvs; + + // + // Fill with max values + // + if (count_ru_histo > info->count) + { + info->fill_buffer(cb, + &info->keyvals_even, + info->count * keyval_bytes, + (count_ru_histo - info->count) * keyval_bytes, + 0xFFFFFFFF); + } + + // + // Zero histograms and invalidate partitions. + // + // Note that the partition invalidation only needs to be performed once + // because the even/odd scatter dispatches rely on the the previous pass to + // leave the partitions in an invalid state. + // + // Note that the last workgroup doesn't read/write a partition so it doesn't + // need to be initialized. + // + uint32_t const histo_partition_count = passes + scatter_blocks - 1; + uint32_t pass_idx = (keyval_bytes - passes); + + VkDeviceSize const fill_base = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t)); + + info->fill_buffer(cb, + &info->internal, + rs->internal.histograms.offset + fill_base, + histo_partition_count * (RS_RADIX_SIZE * sizeof(uint32_t)), + 0); + + //////////////////////////////////////////////////////////////////////// + // + // Pipeline: HISTOGRAM + // + // TODO(allanmac): All subgroups should try to process approximately the same + // number of blocks in order to minimize tail effects. This was implemented + // and reverted but should be reimplemented and benchmarked later. 
+ // +#ifdef RS_VK_ENABLE_EXTENSIONS + rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TRANSFER_BIT); +#endif + + vk_barrier_transfer_w_to_compute_r(cb); + + // clang-format off + VkDeviceAddress const devaddr_histograms = info->internal.devaddr + rs->internal.histograms.offset; + VkDeviceAddress const devaddr_keyvals_even = info->keyvals_even.devaddr; + // clang-format on + + // + // Dispatch histogram + // + struct rs_push_histogram const push_histogram = { + + .devaddr_histograms = devaddr_histograms, + .devaddr_keyvals = devaddr_keyvals_even, + .passes = passes + }; + + vkCmdPushConstants(cb, + rs->pipeline_layouts.named.histogram, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + sizeof(push_histogram), + &push_histogram); + + vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram); + + vkCmdDispatch(cb, histo_blocks, 1, 1); + + //////////////////////////////////////////////////////////////////////// + // + // Pipeline: PREFIX + // + // Launch one workgroup per pass. + // +#ifdef RS_VK_ENABLE_EXTENSIONS + rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); +#endif + + vk_barrier_compute_w_to_compute_r(cb); + + struct rs_push_prefix const push_prefix = { + + .devaddr_histograms = devaddr_histograms, + }; + + vkCmdPushConstants(cb, + rs->pipeline_layouts.named.prefix, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + sizeof(push_prefix), + &push_prefix); + + vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix); + + vkCmdDispatch(cb, passes, 1, 1); + + //////////////////////////////////////////////////////////////////////// + // + // Pipeline: SCATTER + // +#ifdef RS_VK_ENABLE_EXTENSIONS + rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); +#endif + + vk_barrier_compute_w_to_compute_r(cb); + + // clang-format off + uint32_t const histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t)); + VkDeviceAddress const devaddr_keyvals_odd = info->keyvals_odd; + VkDeviceAddress const devaddr_partitions = info->internal.devaddr + rs->internal.partitions.offset; + // clang-format on + + struct rs_push_scatter push_scatter = { + + .devaddr_keyvals_even = devaddr_keyvals_even, + .devaddr_keyvals_odd = devaddr_keyvals_odd, + .devaddr_partitions = devaddr_partitions, + .devaddr_histograms = devaddr_histograms + histogram_offset, + .pass_offset = (pass_idx & 3) * RS_RADIX_LOG2, + }; + + { + uint32_t const pass_dword = pass_idx / 4; + + vkCmdPushConstants(cb, + rs->pipeline_layouts.named.scatter[pass_dword].even, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + sizeof(push_scatter), + &push_scatter); + + vkCmdBindPipeline(cb, + VK_PIPELINE_BIND_POINT_COMPUTE, + rs->pipelines.named.scatter[pass_dword].even); + } + + bool is_even = true; + + while (true) + { + vkCmdDispatch(cb, scatter_blocks, 1, 1); + + // + // Continue? + // + if (++pass_idx >= keyval_bytes) + break; + +#ifdef RS_VK_ENABLE_EXTENSIONS + rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); +#endif + vk_barrier_compute_w_to_compute_r(cb); + + // clang-format off + is_even ^= true; + push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t)); + push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2; + // clang-format on + + uint32_t const pass_dword = pass_idx / 4; + + // + // Update push constants that changed + // + VkPipelineLayout const pl = is_even ? 
rs->pipeline_layouts.named.scatter[pass_dword].even // + : rs->pipeline_layouts.named.scatter[pass_dword].odd; + vkCmdPushConstants(cb, + pl, + VK_SHADER_STAGE_COMPUTE_BIT, + OFFSETOF_MACRO(struct rs_push_scatter, devaddr_histograms), + sizeof(push_scatter.devaddr_histograms) + sizeof(push_scatter.pass_offset), + &push_scatter.devaddr_histograms); + + // + // Bind new pipeline + // + VkPipeline const p = is_even ? rs->pipelines.named.scatter[pass_dword].even // + : rs->pipelines.named.scatter[pass_dword].odd; + + vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, p); + } + +#ifdef RS_VK_ENABLE_EXTENSIONS + rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); +#endif + + // + // End the label + // +#ifdef RS_VK_ENABLE_DEBUG_UTILS + if (pfn_vkCmdEndDebugUtilsLabelEXT != NULL) + { + pfn_vkCmdEndDebugUtilsLabelEXT(cb); + } +#endif +} + +// +// +// +void +radix_sort_vk_sort_indirect_devaddr(radix_sort_vk_t const * rs, + radix_sort_vk_sort_indirect_devaddr_info_t const * info, + VkDevice device, + VkCommandBuffer cb, + VkDeviceAddress * keyvals_sorted) +{ + // + // Anything to do? + // + if (info->key_bits == 0) + { + *keyvals_sorted = info->keyvals_even; + return; + } + +#ifdef RS_VK_ENABLE_EXTENSIONS + // + // Any extensions? + // + struct radix_sort_vk_ext_timestamps * ext_timestamps = NULL; + + void * ext_next = info->ext; + + while (ext_next != NULL) + { + struct radix_sort_vk_ext_base * const base = ext_next; + + switch (base->type) + { + case RS_VK_EXT_TIMESTAMPS: + ext_timestamps = ext_next; + ext_timestamps->timestamps_set = 0; + break; + } + + ext_next = base->ext; + } +#endif + + //////////////////////////////////////////////////////////////////////// + // + // OVERVIEW + // + // 1. Init + // --- BARRIER --- + // 2. Pad the keyvals in `scatter_even`. + // 3. Zero the `histograms` and `partitions`. + // --- BARRIER --- + // 4. HISTOGRAM is dispatched before PREFIX. + // --- BARRIER --- + // 5. PREFIX is dispatched before the first SCATTER. + // --- BARRIER --- + // 6. One or more SCATTER dispatches. + // + // Note that the `partitions` buffer can be zeroed anytime before the first + // scatter. + // + //////////////////////////////////////////////////////////////////////// + + // + // Label the command buffer + // +#ifdef RS_VK_ENABLE_DEBUG_UTILS + if (pfn_vkCmdBeginDebugUtilsLabelEXT != NULL) + { + VkDebugUtilsLabelEXT const label = { + .sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT, + .pNext = NULL, + .pLabelName = "radix_sort_vk_sort_indirect", + }; + + pfn_vkCmdBeginDebugUtilsLabelEXT(cb, &label); + } +#endif + + // + // How many passes? + // + uint32_t const keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t); + uint32_t const keyval_bits = keyval_bytes * 8; + uint32_t const key_bits = MIN_MACRO(uint32_t, info->key_bits, keyval_bits); + uint32_t const passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2; + uint32_t pass_idx = (keyval_bytes - passes); + + *keyvals_sorted = ((passes & 1) != 0) ? info->keyvals_odd : info->keyvals_even; + + // + // NOTE(allanmac): Some of these initializations appear redundant but for now + // we're going to assume the compiler will elide them. 
+ // + // clang-format off + VkDeviceAddress const devaddr_info = info->indirect.devaddr; + VkDeviceAddress const devaddr_count = info->count; + VkDeviceAddress const devaddr_histograms = info->internal + rs->internal.histograms.offset; + VkDeviceAddress const devaddr_keyvals_even = info->keyvals_even; + // clang-format on + + // + // START + // +#ifdef RS_VK_ENABLE_EXTENSIONS + rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT); +#endif + + // + // INIT + // + { + struct rs_push_init const push_init = { + + .devaddr_info = devaddr_info, + .devaddr_count = devaddr_count, + .passes = passes + }; + + vkCmdPushConstants(cb, + rs->pipeline_layouts.named.init, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + sizeof(push_init), + &push_init); + + vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.init); + + vkCmdDispatch(cb, 1, 1, 1); + } + +#ifdef RS_VK_ENABLE_EXTENSIONS + rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); +#endif + + vk_barrier_compute_w_to_indirect_compute_r(cb); + + { + // + // PAD + // + struct rs_push_fill const push_pad = { + + .devaddr_info = devaddr_info + offsetof(struct rs_indirect_info, pad), + .devaddr_dwords = devaddr_keyvals_even, + .dword = 0xFFFFFFFF + }; + + vkCmdPushConstants(cb, + rs->pipeline_layouts.named.fill, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + sizeof(push_pad), + &push_pad); + + vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.fill); + + info->dispatch_indirect(cb, &info->indirect, offsetof(struct rs_indirect_info, dispatch.pad)); + } + + // + // ZERO + // + { + VkDeviceSize const histo_offset = pass_idx * (sizeof(uint32_t) * RS_RADIX_SIZE); + + struct rs_push_fill const push_zero = { + + .devaddr_info = devaddr_info + offsetof(struct rs_indirect_info, zero), + .devaddr_dwords = devaddr_histograms + histo_offset, + .dword = 0 + }; + + vkCmdPushConstants(cb, + rs->pipeline_layouts.named.fill, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + sizeof(push_zero), + &push_zero); + + vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.fill); + + info->dispatch_indirect(cb, &info->indirect, offsetof(struct rs_indirect_info, dispatch.zero)); + } + +#ifdef RS_VK_ENABLE_EXTENSIONS + rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); +#endif + + vk_barrier_compute_w_to_compute_r(cb); + + // + // HISTOGRAM + // + { + struct rs_push_histogram const push_histogram = { + + .devaddr_histograms = devaddr_histograms, + .devaddr_keyvals = devaddr_keyvals_even, + .passes = passes + }; + + vkCmdPushConstants(cb, + rs->pipeline_layouts.named.histogram, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + sizeof(push_histogram), + &push_histogram); + + vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram); + + info->dispatch_indirect(cb, + &info->indirect, + offsetof(struct rs_indirect_info, dispatch.histogram)); + } + +#ifdef RS_VK_ENABLE_EXTENSIONS + rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); +#endif + + vk_barrier_compute_w_to_compute_r(cb); + + // + // PREFIX + // + { + struct rs_push_prefix const push_prefix = { + .devaddr_histograms = devaddr_histograms, + }; + + vkCmdPushConstants(cb, + rs->pipeline_layouts.named.prefix, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + sizeof(push_prefix), + &push_prefix); + + vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix); + + vkCmdDispatch(cb, passes, 1, 1); + } + +#ifdef RS_VK_ENABLE_EXTENSIONS + 
rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); +#endif + + vk_barrier_compute_w_to_compute_r(cb); + + // + // SCATTER + // + { + // clang-format off + uint32_t const histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t)); + VkDeviceAddress const devaddr_keyvals_odd = info->keyvals_odd; + VkDeviceAddress const devaddr_partitions = info->internal + rs->internal.partitions.offset; + // clang-format on + + struct rs_push_scatter push_scatter = { + .devaddr_keyvals_even = devaddr_keyvals_even, + .devaddr_keyvals_odd = devaddr_keyvals_odd, + .devaddr_partitions = devaddr_partitions, + .devaddr_histograms = devaddr_histograms + histogram_offset, + .pass_offset = (pass_idx & 3) * RS_RADIX_LOG2, + }; + + { + uint32_t const pass_dword = pass_idx / 4; + + vkCmdPushConstants(cb, + rs->pipeline_layouts.named.scatter[pass_dword].even, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + sizeof(push_scatter), + &push_scatter); + + vkCmdBindPipeline(cb, + VK_PIPELINE_BIND_POINT_COMPUTE, + rs->pipelines.named.scatter[pass_dword].even); + } + + bool is_even = true; + + while (true) + { + info->dispatch_indirect(cb, + &info->indirect, + offsetof(struct rs_indirect_info, dispatch.scatter)); + + // + // Continue? + // + if (++pass_idx >= keyval_bytes) + break; + +#ifdef RS_VK_ENABLE_EXTENSIONS + rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); +#endif + + vk_barrier_compute_w_to_compute_r(cb); + + // clang-format off + is_even ^= true; + push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t)); + push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2; + // clang-format on + + uint32_t const pass_dword = pass_idx / 4; + + // + // Update push constants that changed + // + VkPipelineLayout const pl = is_even + ? rs->pipeline_layouts.named.scatter[pass_dword].even // + : rs->pipeline_layouts.named.scatter[pass_dword].odd; + vkCmdPushConstants( + cb, + pl, + VK_SHADER_STAGE_COMPUTE_BIT, + OFFSETOF_MACRO(struct rs_push_scatter, devaddr_histograms), + sizeof(push_scatter.devaddr_histograms) + sizeof(push_scatter.pass_offset), + &push_scatter.devaddr_histograms); + + // + // Bind new pipeline + // + VkPipeline const p = is_even ? 
rs->pipelines.named.scatter[pass_dword].even // + : rs->pipelines.named.scatter[pass_dword].odd; + + vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, p); + } + } + +#ifdef RS_VK_ENABLE_EXTENSIONS + rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT); +#endif + + // + // End the label + // +#ifdef RS_VK_ENABLE_DEBUG_UTILS + if (pfn_vkCmdEndDebugUtilsLabelEXT != NULL) + { + pfn_vkCmdEndDebugUtilsLabelEXT(cb); + } +#endif +} + +// +// +// +static void +radix_sort_vk_fill_buffer(VkCommandBuffer cb, + radix_sort_vk_buffer_info_t const * buffer_info, + VkDeviceSize offset, + VkDeviceSize size, + uint32_t data) +{ + vkCmdFillBuffer(cb, buffer_info->buffer, buffer_info->offset + offset, size, data); +} + +// +// +// +void +radix_sort_vk_sort(radix_sort_vk_t const * rs, + radix_sort_vk_sort_info_t const * info, + VkDevice device, + VkCommandBuffer cb, + VkDescriptorBufferInfo * keyvals_sorted) +{ + struct radix_sort_vk_sort_devaddr_info const di = { + .ext = info->ext, + .key_bits = info->key_bits, + .count = info->count, + .keyvals_even = { .buffer = info->keyvals_even.buffer, + .offset = info->keyvals_even.offset, + .devaddr = rs_get_devaddr(device, &info->keyvals_even) }, + .keyvals_odd = rs_get_devaddr(device, &info->keyvals_odd), + .internal = { .buffer = info->internal.buffer, + .offset = info->internal.offset, + .devaddr = rs_get_devaddr(device, &info->internal), }, + .fill_buffer = radix_sort_vk_fill_buffer, + }; + + VkDeviceAddress di_keyvals_sorted; + + radix_sort_vk_sort_devaddr(rs, &di, device, cb, &di_keyvals_sorted); + + *keyvals_sorted = (di_keyvals_sorted == di.keyvals_even.devaddr) // + ? info->keyvals_even + : info->keyvals_odd; +} + +// +// +// +static void +radix_sort_vk_dispatch_indirect(VkCommandBuffer cb, + radix_sort_vk_buffer_info_t const * buffer_info, + VkDeviceSize offset) +{ + vkCmdDispatchIndirect(cb, buffer_info->buffer, buffer_info->offset + offset); +} + +// +// +// +void +radix_sort_vk_sort_indirect(radix_sort_vk_t const * rs, + radix_sort_vk_sort_indirect_info_t const * info, + VkDevice device, + VkCommandBuffer cb, + VkDescriptorBufferInfo * keyvals_sorted) +{ + struct radix_sort_vk_sort_indirect_devaddr_info const idi = { + .ext = info->ext, + .key_bits = info->key_bits, + .count = rs_get_devaddr(device, &info->count), + .keyvals_even = rs_get_devaddr(device, &info->keyvals_even), + .keyvals_odd = rs_get_devaddr(device, &info->keyvals_odd), + .internal = rs_get_devaddr(device, &info->internal), + .indirect = { .buffer = info->indirect.buffer, + .offset = info->indirect.offset, + .devaddr = rs_get_devaddr(device, &info->indirect) }, + .dispatch_indirect = radix_sort_vk_dispatch_indirect, + }; + + VkDeviceAddress idi_keyvals_sorted; + + radix_sort_vk_sort_indirect_devaddr(rs, &idi, device, cb, &idi_keyvals_sorted); + + *keyvals_sorted = (idi_keyvals_sorted == idi.keyvals_even) // + ? info->keyvals_even + : info->keyvals_odd; +} + +// +// +// diff --git a/src/amd/vulkan/radix_sort/radix_sort_vk.h b/src/amd/vulkan/radix_sort/radix_sort_vk.h new file mode 100644 index 00000000000..a1416ef133f --- /dev/null +++ b/src/amd/vulkan/radix_sort/radix_sort_vk.h @@ -0,0 +1,384 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+
+#ifndef SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_INCLUDE_RADIX_SORT_PLATFORMS_VK_RADIX_SORT_VK_H_
+#define SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_INCLUDE_RADIX_SORT_PLATFORMS_VK_RADIX_SORT_VK_H_
+
+//
+//
+//
+
+#include <vulkan/vulkan_core.h>
+
+//
+//
+//
+
+#include <stdbool.h>
+#include <stdint.h>
+
+//
+//
+//
+
+#include "target.h"
+
+//
+// Radix Sort Vk is a high-performance sorting library for Vulkan 1.2.
+//
+// The sorting function is both directly and indirectly dispatchable.
+//
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//
+// Get a Radix Sort target's Vulkan requirements.
+//
+// A Radix Sort target is a binary image containing configuration parameters and
+// a bundle of SPIR-V modules.
+//
+// Targets are prebuilt and specific to a particular device vendor, architecture
+// and key-val configuration.
+//
+// A Radix Sort instance can only be created with a VkDevice that is initialized
+// with all of the target's required extensions and features.
+//
+// The `radix_sort_vk_target_get_requirements()` function yields the extensions
+// and initialized feature flags required by a Radix Sort target.
+//
+// These requirements can be merged with other Vulkan library requirements
+// before VkDevice creation.
+//
+// If the `.ext_names` member is NULL, the `.ext_name_count` member will be
+// initialized.
+//
+// Returns `false` if:
+//
+//   * The .ext_names field is NULL and the number of required extensions is
+//     greater than zero.
+//   * The .ext_name_count is less than the number of required extensions.
+//   * Any of the .pdf, .pdf11 or .pdf12 members are NULL.
+//
+// Otherwise, returns true.
+//
+typedef struct radix_sort_vk_target radix_sort_vk_target_t;
+
+//
+// NOTE: The library currently supports uint32_t and uint64_t keyvals.
+//
+
+#define RS_KV_DWORDS_MAX 2
+
+//
+//
+//
+
+struct rs_pipeline_layout_scatter
+{
+  VkPipelineLayout even;
+  VkPipelineLayout odd;
+};
+
+struct rs_pipeline_scatter
+{
+  VkPipeline even;
+  VkPipeline odd;
+};
+
+//
+//
+//
+
+struct rs_pipeline_layouts_named
+{
+  VkPipelineLayout init;
+  VkPipelineLayout fill;
+  VkPipelineLayout histogram;
+  VkPipelineLayout prefix;
+  struct rs_pipeline_layout_scatter scatter[RS_KV_DWORDS_MAX];
+};
+
+struct rs_pipelines_named
+{
+  VkPipeline init;
+  VkPipeline fill;
+  VkPipeline histogram;
+  VkPipeline prefix;
+  struct rs_pipeline_scatter scatter[RS_KV_DWORDS_MAX];
+};
+
+// clang-format off
+#define RS_PIPELINE_LAYOUTS_HANDLES (sizeof(struct rs_pipeline_layouts_named) / sizeof(VkPipelineLayout))
+#define RS_PIPELINES_HANDLES        (sizeof(struct rs_pipelines_named) / sizeof(VkPipeline))
+// clang-format on
+
+//
+//
+//
+
+struct radix_sort_vk
+{
+  struct radix_sort_vk_target_config config;
+
+  union
+  {
+    struct rs_pipeline_layouts_named named;
+    VkPipelineLayout handles[RS_PIPELINE_LAYOUTS_HANDLES];
+  } pipeline_layouts;
+
+  union
+  {
+    struct rs_pipelines_named named;
+    VkPipeline handles[RS_PIPELINES_HANDLES];
+  } pipelines;
+
+  struct
+  {
+    struct
+    {
+      VkDeviceSize offset;
+      VkDeviceSize range;
+    } histograms;
+
+    struct
+    {
+      VkDeviceSize offset;
+    } partitions;
+
+  } internal;
+};
+
+//
+// Create a Radix Sort instance for a target.
+//
+// Keyval size is implicitly determined by the target.
+//
+// Returns NULL on failure.
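+//
+// For illustration only, a minimal create/destroy sketch (error handling is
+// omitted, and `device` and `pipeline_cache` are caller-provided handles; in
+// this patch, RADV's radv_create_radix_sort_u64() helper wraps
+// radix_sort_vk_create() with the prebuilt SPIR-V modules and the u64 target
+// configuration):
+//
+//   radix_sort_vk_t * rs = radv_create_radix_sort_u64(device, NULL, pipeline_cache);
+//
+//   /* ... record one or more sorts ... */
+//
+//   radix_sort_vk_destroy(rs, device, NULL);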
+// +typedef struct radix_sort_vk radix_sort_vk_t; + +// +// +// +radix_sort_vk_t * +radix_sort_vk_create(VkDevice device, + VkAllocationCallbacks const * ac, + VkPipelineCache pc, + const uint32_t* const* spv, + const uint32_t* spv_sizes, + struct radix_sort_vk_target_config config); + +// +// Destroy the Radix Sort instance using the same device and allocator used at +// creation. +// +void +radix_sort_vk_destroy(radix_sort_vk_t * rs, // + VkDevice d, + VkAllocationCallbacks const * ac); + +// +// Returns the buffer size and alignment requirements for a maximum number of +// keyvals. +// +// The radix sort implementation is not an in-place sorting algorithm so two +// non-overlapping keyval buffers are required that are at least +// `.keyvals_size`. +// +// The radix sort instance also requires an `internal` buffer during sorting. +// +// If the indirect dispatch sorting function is used, then an `indirect` buffer +// is also required. +// +// The alignment requirements for the keyval, internal, and indirect buffers +// must be honored. All alignments are power of 2. +// +// Input: +// count : Maximum number of keyvals +// +// Outputs: +// keyval_size : Size of a single keyval +// +// keyvals_size : Minimum size of the even and odd keyval buffers +// keyvals_alignment : Alignment of each keyval buffer +// +// internal_size : Minimum size of internal buffer +// internal_aligment : Alignment of the internal buffer +// +// indirect_size : Minimum size of indirect buffer +// indirect_aligment : Alignment of the indirect buffer +// +// .keyvals_even/odd +// ----------------- +// VK_BUFFER_USAGE_STORAGE_BUFFER_BIT +// VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT +// +// .internal +// --------- +// VK_BUFFER_USAGE_STORAGE_BUFFER_BIT +// VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT +// VK_BUFFER_USAGE_TRANSFER_DST_BIT ("direct" mode only) +// +// .indirect +// --------- +// VK_BUFFER_USAGE_STORAGE_BUFFER_BIT +// VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT +// VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT +// +typedef struct radix_sort_vk_memory_requirements +{ + VkDeviceSize keyval_size; + + VkDeviceSize keyvals_size; + VkDeviceSize keyvals_alignment; + + VkDeviceSize internal_size; + VkDeviceSize internal_alignment; + + VkDeviceSize indirect_size; + VkDeviceSize indirect_alignment; +} radix_sort_vk_memory_requirements_t; + +void +radix_sort_vk_get_memory_requirements(radix_sort_vk_t const * rs, + uint32_t count, + radix_sort_vk_memory_requirements_t * mr); + +// +// Direct dispatch sorting +// ----------------------- +// +// Using a key size of `key_bits`, sort `count` keyvals found in the +// `.devaddr_keyvals_even` buffer. +// +// Each internal sorting pass copies the keyvals from one keyvals buffer to the +// other. +// +// The number of internal sorting passes is determined by `.key_bits`. +// +// If an even number of internal sorting passes is required, the sorted keyvals +// will be found in the "even" keyvals buffer. Otherwise, the sorted keyvals +// will be found in the "odd" keyvals buffer. +// +// Which buffer has the sorted keyvals is returned in `keyvals_sorted`. +// +// A keyval's `key_bits` are the most significant bits of a keyval. +// +// The maximum number of key bits is determined by the keyval size. +// +// The keyval count must be less than (1 << 30) as well as be less than or equal +// to the count used to obtain the the memory requirements. +// +// The info struct's `ext` member must be NULL. +// +// This function appends push constants, dispatch commands, and barriers. 
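+//
+// For illustration only, a direct-dispatch sketch -- buffer creation, memory
+// binding and command buffer recording are omitted, and `keyvals_even`,
+// `keyvals_odd` and `internal` stand for caller-provided
+// VkDescriptorBufferInfo structs describing buffers sized with
+// radix_sort_vk_get_memory_requirements():
+//
+//   radix_sort_vk_memory_requirements_t mr;
+//
+//   radix_sort_vk_get_memory_requirements(rs, count, &mr);
+//
+//   /* ... allocate the keyval and internal buffers using the sizes,
+//      alignments and usage flags listed above ... */
+//
+//   radix_sort_vk_sort_info_t info = {
+//     .ext          = NULL,
+//     .key_bits     = 32,  /* sort on the top 32 bits of each keyval */
+//     .count        = count,
+//     .keyvals_even = keyvals_even,
+//     .keyvals_odd  = keyvals_odd,
+//     .internal     = internal,
+//   };
+//
+//   VkDescriptorBufferInfo keyvals_sorted;
+//
+//   radix_sort_vk_sort(rs, &info, device, cb, &keyvals_sorted);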
+// +// Pipeline barriers should be applied as necessary, both before and after +// invoking this function. +// +// The sort begins with either a TRANSFER/WRITE or a COMPUTE/READ to the +// `internal` and `keyvals_even` buffers. +// +// The sort ends with a COMPUTE/WRITE to the `internal` and `keyvals_sorted` +// buffers. +// + +// +// Direct dispatch sorting using VkDescriptorBufferInfo structures +// --------------------------------------------------------------- +// +typedef struct radix_sort_vk_sort_info +{ + void * ext; + uint32_t key_bits; + uint32_t count; + VkDescriptorBufferInfo keyvals_even; + VkDescriptorBufferInfo keyvals_odd; + VkDescriptorBufferInfo internal; +} radix_sort_vk_sort_info_t; + +void +radix_sort_vk_sort(radix_sort_vk_t const * rs, + radix_sort_vk_sort_info_t const * info, + VkDevice device, + VkCommandBuffer cb, + VkDescriptorBufferInfo * keyvals_sorted); + +// +// Indirect dispatch sorting +// ------------------------- +// +// Using a key size of `key_bits`, at pipeline execution time, load keyvals +// count from `devaddr_count` and sorts the keyvals in `.devaddr_keyvals_even`. +// +// Each internal sorting pass copies the keyvals from one keyvals buffer to the +// other. +// +// The number of internal sorting passes is determined by `.key_bits`. +// +// If an even number of internal sorting passes is required, the sorted keyvals +// will be found in the "even" keyvals buffer. Otherwise, the sorted keyvals +// will be found in the "odd" keyvals buffer. +// +// Which buffer has the sorted keyvals is returned in `keyvals_sorted`. +// +// A keyval's `key_bits` are the most significant bits of a keyval. +// +// The keyval count must be less than (1 << 30) as well as be less than or equal +// to the count used to obtain the the memory requirements. +// +// The info struct's `ext` member must be NULL. +// +// This function appends push constants, dispatch commands, and barriers. +// +// Pipeline barriers should be applied as necessary, both before and after +// invoking this function. +// +// The indirect radix sort begins with a COMPUTE/READ from the `count` buffer +// and ends with a COMPUTE/WRITE to the `internal` and the `keyvals_sorted` +// buffers. +// +// The `indirect` buffer must support USAGE_INDIRECT. +// +// The `count` buffer must be at least 4 bytes and 4-byte aligned. +// + +// +// Indirect dispatch sorting using VkDescriptorBufferInfo structures +// ----------------------------------------------------------------- +// +typedef struct radix_sort_vk_sort_indirect_info +{ + void * ext; + uint32_t key_bits; + VkDescriptorBufferInfo count; + VkDescriptorBufferInfo keyvals_even; + VkDescriptorBufferInfo keyvals_odd; + VkDescriptorBufferInfo internal; + VkDescriptorBufferInfo indirect; +} radix_sort_vk_sort_indirect_info_t; + +void +radix_sort_vk_sort_indirect(radix_sort_vk_t const * rs, + radix_sort_vk_sort_indirect_info_t const * info, + VkDevice device, + VkCommandBuffer cb, + VkDescriptorBufferInfo * keyvals_sorted); + +// +// +// + +#ifdef __cplusplus +} +#endif + +// +// +// + +#endif // SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_INCLUDE_RADIX_SORT_PLATFORMS_VK_RADIX_SORT_VK_H_ diff --git a/src/amd/vulkan/radix_sort/radix_sort_vk_devaddr.h b/src/amd/vulkan/radix_sort/radix_sort_vk_devaddr.h new file mode 100644 index 00000000000..23dd808249a --- /dev/null +++ b/src/amd/vulkan/radix_sort/radix_sort_vk_devaddr.h @@ -0,0 +1,104 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_INCLUDE_RADIX_SORT_PLATFORMS_VK_RADIX_SORT_VK_DEVADDR_H_ +#define SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_INCLUDE_RADIX_SORT_PLATFORMS_VK_RADIX_SORT_VK_DEVADDR_H_ + +// +// +// + +#include "radix_sort_vk.h" + +// +// +// + +#ifdef __cplusplus +extern "C" { +#endif + +// +// Structure that enables integration with Vulkan drivers. +// +typedef struct radix_sort_vk_buffer_info +{ + VkBuffer buffer; + VkDeviceSize offset; + VkDeviceAddress devaddr; +} radix_sort_vk_buffer_info_t; + +// +// Function prototypes +// +typedef void (*radix_sort_vk_fill_buffer_pfn)(VkCommandBuffer cb, + radix_sort_vk_buffer_info_t const * buffer_info, + VkDeviceSize offset, + VkDeviceSize size, + uint32_t data); + +typedef void (*radix_sort_vk_dispatch_indirect_pfn)(VkCommandBuffer cb, + radix_sort_vk_buffer_info_t const * buffer_info, + VkDeviceSize offset); + +// +// Direct dispatch sorting using buffer device addresses +// ----------------------------------------------------- +// +typedef struct radix_sort_vk_sort_devaddr_info +{ + void * ext; + uint32_t key_bits; + uint32_t count; + radix_sort_vk_buffer_info_t keyvals_even; + VkDeviceAddress keyvals_odd; + radix_sort_vk_buffer_info_t internal; + radix_sort_vk_fill_buffer_pfn fill_buffer; +} radix_sort_vk_sort_devaddr_info_t; + +void +radix_sort_vk_sort_devaddr(radix_sort_vk_t const * rs, + radix_sort_vk_sort_devaddr_info_t const * info, + VkDevice device, + VkCommandBuffer cb, + VkDeviceAddress * keyvals_sorted); + +// +// Indirect dispatch sorting using buffer device addresses +// ------------------------------------------------------- +// +// clang-format off +// +typedef struct radix_sort_vk_sort_indirect_devaddr_info +{ + void * ext; + uint32_t key_bits; + VkDeviceAddress count; + VkDeviceAddress keyvals_even; + VkDeviceAddress keyvals_odd; + VkDeviceAddress internal; + radix_sort_vk_buffer_info_t indirect; + radix_sort_vk_dispatch_indirect_pfn dispatch_indirect; +} radix_sort_vk_sort_indirect_devaddr_info_t; + +void +radix_sort_vk_sort_indirect_devaddr(radix_sort_vk_t const * rs, + radix_sort_vk_sort_indirect_devaddr_info_t const * info, + VkDevice device, + VkCommandBuffer cb, + VkDeviceAddress * keyvals_sorted); + +// +// clang-format on +// + +#ifdef __cplusplus +} +#endif + +// +// +// + +#endif // SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_INCLUDE_RADIX_SORT_PLATFORMS_VK_RADIX_SORT_VK_DEVADDR_H_ diff --git a/src/amd/vulkan/radix_sort/radix_sort_vk_ext.h b/src/amd/vulkan/radix_sort/radix_sort_vk_ext.h new file mode 100644 index 00000000000..af55c706934 --- /dev/null +++ b/src/amd/vulkan/radix_sort/radix_sort_vk_ext.h @@ -0,0 +1,77 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+
+#ifndef SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_RADIX_SORT_VK_EXT_H_
+#define SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_RADIX_SORT_VK_EXT_H_
+
+//
+//
+//
+
+#include <vulkan/vulkan_core.h>
+
+//
+//
+//
+
+#include <stdbool.h>
+#include <stdint.h>
+
+//
+//
+//
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//
+// Radix sort extensions
+// ---------------------
+//
+#ifndef RADIX_SORT_VK_DISABLE_EXTENSIONS
+
+//
+// Extension types
+//
+enum radix_sort_vk_ext_type
+{
+  RADIX_SORT_VK_EXT_TIMESTAMPS
+};
+
+//
+// Timestamp each logical step of the algorithm
+//
+// Number of timestamps is: 5 + (number of subpasses)
+//
+//   * direct dispatch:   4 + subpass count
+//   * indirect dispatch: 5 + subpass count
+//
+// Indirect / 32-bit keyvals:  9
+// Indirect / 64-bit keyvals: 13
+//
+struct radix_sort_vk_ext_timestamps
+{
+  void * ext;
+  enum radix_sort_vk_ext_type type;
+  uint32_t timestamp_count;
+  VkQueryPool timestamps;
+  uint32_t timestamps_set;
+};
+
+#endif
+
+//
+//
+//
+
+#ifdef __cplusplus
+}
+#endif
+
+//
+//
+//
+
+#endif // SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_RADIX_SORT_VK_EXT_H_
diff --git a/src/amd/vulkan/radix_sort/radv_radix_sort.c b/src/amd/vulkan/radix_sort/radv_radix_sort.c
new file mode 100644
index 00000000000..88b17bdd503
--- /dev/null
+++ b/src/amd/vulkan/radix_sort/radv_radix_sort.c
@@ -0,0 +1,193 @@
+/*
+ * Copyright © 2022 Konstantin Seurer
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice (including the next
+ * paragraph) shall be included in all copies or substantial portions of the
+ * Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */ + +#include "radv_radix_sort.h" +#include "targets/u64/config.h" +#include "radv_private.h" +#include "target.h" + +static const uint32_t init_spv[] = { +#include "radix_sort/shaders/init.comp.spv.h" +}; + +static const uint32_t fill_spv[] = { +#include "radix_sort/shaders/fill.comp.spv.h" +}; + +static const uint32_t histogram_spv[] = { +#include "radix_sort/shaders/histogram.comp.spv.h" +}; + +static const uint32_t prefix_spv[] = { +#include "radix_sort/shaders/prefix.comp.spv.h" +}; + +static const uint32_t scatter_0_even_spv[] = { +#include "radix_sort/shaders/scatter_0_even.comp.spv.h" +}; + +static const uint32_t scatter_0_odd_spv[] = { +#include "radix_sort/shaders/scatter_0_odd.comp.spv.h" +}; + +static const uint32_t scatter_1_even_spv[] = { +#include "radix_sort/shaders/scatter_1_even.comp.spv.h" +}; + +static const uint32_t scatter_1_odd_spv[] = { +#include "radix_sort/shaders/scatter_1_odd.comp.spv.h" +}; + +static const struct radix_sort_vk_target_config target_config = { + .keyval_dwords = RS_KEYVAL_DWORDS, + + .histogram = + { + .workgroup_size_log2 = RS_HISTOGRAM_WORKGROUP_SIZE_LOG2, + .subgroup_size_log2 = RS_HISTOGRAM_SUBGROUP_SIZE_LOG2, + .block_rows = RS_HISTOGRAM_BLOCK_ROWS, + }, + + .prefix = + { + .workgroup_size_log2 = RS_PREFIX_WORKGROUP_SIZE_LOG2, + .subgroup_size_log2 = RS_PREFIX_SUBGROUP_SIZE_LOG2, + }, + + .scatter = + { + .workgroup_size_log2 = RS_SCATTER_WORKGROUP_SIZE_LOG2, + .subgroup_size_log2 = RS_SCATTER_SUBGROUP_SIZE_LOG2, + .block_rows = RS_SCATTER_BLOCK_ROWS, + }, +}; + +radix_sort_vk_t * +radv_create_radix_sort_u64(VkDevice device, VkAllocationCallbacks const *ac, VkPipelineCache pc) +{ + const uint32_t *spv[8] = { + init_spv, fill_spv, histogram_spv, prefix_spv, + scatter_0_even_spv, scatter_0_odd_spv, scatter_1_even_spv, scatter_1_odd_spv, + }; + const uint32_t spv_sizes[8] = { + sizeof(init_spv), sizeof(fill_spv), sizeof(histogram_spv), + sizeof(prefix_spv), sizeof(scatter_0_even_spv), sizeof(scatter_0_odd_spv), + sizeof(scatter_1_even_spv), sizeof(scatter_1_odd_spv), + }; + return radix_sort_vk_create(device, ac, pc, spv, spv_sizes, target_config); +} + +VKAPI_ATTR VkResult VKAPI_CALL +vkCreateShaderModule(VkDevice device, const VkShaderModuleCreateInfo *pCreateInfo, + const VkAllocationCallbacks *pAllocator, VkShaderModule *pShaderModule) +{ + RADV_FROM_HANDLE(radv_device, pdevice, device); + return pdevice->vk.dispatch_table.CreateShaderModule(device, pCreateInfo, pAllocator, + pShaderModule); +} + +VKAPI_ATTR void VKAPI_CALL +vkDestroyShaderModule(VkDevice device, VkShaderModule shaderModule, + const VkAllocationCallbacks *pAllocator) +{ + RADV_FROM_HANDLE(radv_device, pdevice, device); + pdevice->vk.dispatch_table.DestroyShaderModule(device, shaderModule, pAllocator); +} + +VKAPI_ATTR VkResult VKAPI_CALL +vkCreatePipelineLayout(VkDevice device, const VkPipelineLayoutCreateInfo *pCreateInfo, + const VkAllocationCallbacks *pAllocator, VkPipelineLayout *pPipelineLayout) +{ + RADV_FROM_HANDLE(radv_device, pdevice, device); + return pdevice->vk.dispatch_table.CreatePipelineLayout(device, pCreateInfo, pAllocator, + pPipelineLayout); +} + +VKAPI_ATTR void VKAPI_CALL +vkDestroyPipelineLayout(VkDevice device, VkPipelineLayout pipelineLayout, + const VkAllocationCallbacks *pAllocator) +{ + RADV_FROM_HANDLE(radv_device, pdevice, device); + pdevice->vk.dispatch_table.DestroyPipelineLayout(device, pipelineLayout, pAllocator); +} + +VKAPI_ATTR VkResult VKAPI_CALL +vkCreateComputePipelines(VkDevice device, VkPipelineCache pipelineCache, uint32_t 
createInfoCount, + const VkComputePipelineCreateInfo *pCreateInfos, + const VkAllocationCallbacks *pAllocator, VkPipeline *pPipelines) +{ + RADV_FROM_HANDLE(radv_device, pdevice, device); + return pdevice->vk.dispatch_table.CreateComputePipelines(device, pipelineCache, createInfoCount, + pCreateInfos, pAllocator, pPipelines); +} + +VKAPI_ATTR void VKAPI_CALL +vkDestroyPipeline(VkDevice device, VkPipeline pipeline, const VkAllocationCallbacks *pAllocator) +{ + RADV_FROM_HANDLE(radv_device, pdevice, device); + return pdevice->vk.dispatch_table.DestroyPipeline(device, pipeline, pAllocator); +} + +VKAPI_ATTR void VKAPI_CALL +vkCmdPipelineBarrier(VkCommandBuffer commandBuffer, VkPipelineStageFlags srcStageMask, + VkPipelineStageFlags dstStageMask, VkDependencyFlags dependencyFlags, + uint32_t memoryBarrierCount, const VkMemoryBarrier *pMemoryBarriers, + uint32_t bufferMemoryBarrierCount, + const VkBufferMemoryBarrier *pBufferMemoryBarriers, + uint32_t imageMemoryBarrierCount, + const VkImageMemoryBarrier *pImageMemoryBarriers) +{ + RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer); + cmd_buffer->device->vk.dispatch_table.CmdPipelineBarrier( + commandBuffer, srcStageMask, dstStageMask, dependencyFlags, memoryBarrierCount, + pMemoryBarriers, bufferMemoryBarrierCount, pBufferMemoryBarriers, imageMemoryBarrierCount, + pImageMemoryBarriers); +} + +VKAPI_ATTR void VKAPI_CALL +vkCmdPushConstants(VkCommandBuffer commandBuffer, VkPipelineLayout layout, + VkShaderStageFlags stageFlags, uint32_t offset, uint32_t size, + const void *pValues) +{ + RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer); + cmd_buffer->device->vk.dispatch_table.CmdPushConstants(commandBuffer, layout, stageFlags, offset, + size, pValues); +} + +VKAPI_ATTR void VKAPI_CALL +vkCmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipelineBindPoint, + VkPipeline pipeline) +{ + RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer); + cmd_buffer->device->vk.dispatch_table.CmdBindPipeline(commandBuffer, pipelineBindPoint, + pipeline); +} + +VKAPI_ATTR void VKAPI_CALL +vkCmdDispatch(VkCommandBuffer commandBuffer, uint32_t groupCountX, uint32_t groupCountY, + uint32_t groupCountZ) +{ + RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer); + cmd_buffer->device->vk.dispatch_table.CmdDispatch(commandBuffer, groupCountX, groupCountY, + groupCountZ); +} diff --git a/src/amd/vulkan/radix_sort/radv_radix_sort.h b/src/amd/vulkan/radix_sort/radv_radix_sort.h new file mode 100644 index 00000000000..513049f467a --- /dev/null +++ b/src/amd/vulkan/radix_sort/radv_radix_sort.h @@ -0,0 +1,32 @@ +/* + * Copyright © 2022 Konstantin Seurer + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice (including the next + * paragraph) shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef RADV_RADIX_SORT_H +#define RADV_RADIX_SORT_H + +#include "radix_sort_vk_devaddr.h" + +radix_sort_vk_t *radv_create_radix_sort_u64(VkDevice device, VkAllocationCallbacks const *ac, + VkPipelineCache pc); + +#endif diff --git a/src/amd/vulkan/radix_sort/shaders/bufref.h b/src/amd/vulkan/radix_sort/shaders/bufref.h new file mode 100644 index 00000000000..ea3319db60c --- /dev/null +++ b/src/amd/vulkan/radix_sort/shaders/bufref.h @@ -0,0 +1,151 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_BUFREF_H_ +#define SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_BUFREF_H_ + +// +// GLSL +// + +#ifdef VULKAN // defined by GLSL/VK compiler + +#extension GL_EXT_shader_explicit_arithmetic_types : require + +// +// If the target does not support VkPhysicalDeviceFeatures.shaderInt64 +// then: +// +// #define RS_DISABLE_SHADER_INT64 +// +// clang-format off +#ifdef RS_DISABLE_SHADER_INT64 +#extension GL_EXT_buffer_reference_uvec2 : require +#else +#extension GL_EXT_buffer_reference2 : require +#endif +// clang-format on + +// +// Restrict shouldn't have any noticeable impact on these kernels and +// benchmarks appear to prove that true but it's correct to include +// the qualifier. +// +#define RS_RESTRICT restrict + +// +// If the device doesn't support .shaderInt64 then the buffer reference address +// is a uvec2. +// +#ifdef RS_DISABLE_SHADER_INT64 +#define RS_DEVADDR u32vec2 +#else +#define RS_DEVADDR uint64_t +#endif + +// +// Define a buffer reference. +// +#define RS_BUFREF_DEFINE(_layout, _name, _devaddr) RS_RESTRICT _layout _name = _layout(_devaddr) + +// +// Define a buffer reference at a UINT32 offset. +// +#ifdef RS_DISABLE_SHADER_INT64 +#define RS_BUFREF_DEFINE_AT_OFFSET_UINT32(_layout, _name, _devaddr_u32vec2, _offset) \ + RS_RESTRICT _layout _name; \ + { \ + u32vec2 devaddr; \ + uint32_t carry; \ + \ + devaddr.x = uaddCarry(_devaddr_u32vec2.x, _offset, carry); \ + devaddr.y = _devaddr_u32vec2.y + carry; \ + \ + _name = _layout(devaddr); \ + } +#else +#define RS_BUFREF_DEFINE_AT_OFFSET_UINT32(_layout, _name, _devaddr, _offset) \ + RS_RESTRICT _layout _name = _layout(_devaddr + _offset) +#endif + +// +// Define a buffer reference at a packed UINT64 offset. +// +#ifdef RS_DISABLE_SHADER_INT64 +#define RS_BUFREF_DEFINE_AT_OFFSET_U32VEC2(_layout, _name, _devaddr_u32vec2, _offset_u32vec2) \ + RS_RESTRICT _layout _name; \ + { \ + u32vec2 devaddr; \ + uint32_t carry; \ + \ + devaddr.x = uaddCarry(_devaddr_u32vec2.x, _offset_u32vec2.x, carry); \ + devaddr.y = _devaddr_u32vec2.y + _offset_u32vec2.y + carry; \ + \ + _name = _layout(devaddr); \ + } +#else +#define RS_BUFREF_DEFINE_AT_OFFSET_U32VEC2(_layout, _name, _devaddr, _offset_u32vec2) \ + RS_RESTRICT _layout _name = _layout(_devaddr + pack64(_offset_u32vec2)) +#endif + +// +// Increment the buffer reference by a UINT32 offset. 
+// +#ifdef RS_DISABLE_SHADER_INT64 +#define RS_BUFREF_INC_UINT32(_layout, _name, _inc) \ + { \ + u32vec2 devaddr = u32vec2(_name); \ + uint32_t carry; \ + \ + devaddr.x = uaddCarry(devaddr.x, _inc, carry); \ + devaddr.y = devaddr.y + carry; \ + \ + _name = _layout(devaddr); \ + } +#else +#define RS_BUFREF_INC_UINT32(_layout, _name, _inc) _name = _layout(uint64_t(_name) + _inc) +#endif + +// +// Increment the buffer reference by a packed UINT64 offset. +// +#ifdef RS_DISABLE_SHADER_INT64 +#define RS_BUFREF_INC_U32VEC2(_layout, _name, _inc_u32vec2) \ + { \ + u32vec2 devaddr = u32vec2(_name); \ + uint32_t carry; \ + \ + devaddr.x = uaddCarry(devaddr.x, _inc_u32vec2.x, carry); \ + devaddr.y = devaddr.y + _inc_u32vec2.y + carry; \ + \ + _name = _layout(devaddr); \ + } +#else +#define RS_BUFREF_INC_U32VEC2(_layout, _name, _inc_u32vec2) \ + _name = _layout(uint64_t(_name) + pack64(_inc_u32vec2)) +#endif + +// +// Increment the buffer reference by the product of two UINT32 factors. +// +#define RS_BUFREF_INC_UINT32_UINT32(_layout, _name, _inc_a, _inc_b) \ + { \ + u32vec2 inc; \ + \ + umulExtended(_inc_a, _inc_b, inc.y, inc.x); \ + \ + RS_BUFREF_INC_U32VEC2(_layout, _name, inc); \ + } + +// +// +// + +#endif + +// +// +// + +#endif // SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_BUFREF_H_ diff --git a/src/amd/vulkan/radix_sort/shaders/fill.comp b/src/amd/vulkan/radix_sort/shaders/fill.comp new file mode 100644 index 00000000000..76b446d8c5d --- /dev/null +++ b/src/amd/vulkan/radix_sort/shaders/fill.comp @@ -0,0 +1,143 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#version 460 + +// +// Each workgroup fills up to RS_BLOCK_KEYVALS +// + +// clang-format off +#extension GL_GOOGLE_include_directive : require +#extension GL_EXT_control_flow_attributes : require +// clang-format on + +// +// Load arch/keyval configuration +// +#include "config.h" + +// +// Buffer reference macros and push constants +// +#include "bufref.h" +#include "push.h" + +// +// Subgroup uniform support +// +#if defined(RS_SCATTER_SUBGROUP_UNIFORM_DISABLE) && defined(GL_EXT_subgroupuniform_qualifier) +#extension GL_EXT_subgroupuniform_qualifier : required +#define RS_SUBGROUP_UNIFORM subgroupuniformEXT +#else +#define RS_SUBGROUP_UNIFORM +#endif + +// +// Declare the push constants +// +RS_STRUCT_PUSH_FILL(); + +layout(push_constant) uniform block_push +{ + rs_push_fill push; +}; + +// +// The "init" shader configures the fill info structure. 
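+//
+// (As used by main() below, the struct expanded by this macro carries the
+// .block_offset, .dword_offset_min and .dword_offset_max_minus_min fields
+// written by init.comp.)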
+// +RS_STRUCT_INDIRECT_INFO_FILL(); + +// +// Check all switches are defined +// +#ifndef RS_FILL_WORKGROUP_SIZE_LOG2 +#error "Undefined: RS_FILL_WORKGROUP_SIZE_LOG2" +#endif + +// +#ifndef RS_FILL_BLOCK_ROWS +#error "Undefined: RS_FILL_BLOCK_ROWS" +#endif + +// +// Local macros +// +// clang-format off +#define RS_WORKGROUP_SIZE (1 << RS_FILL_WORKGROUP_SIZE_LOG2) +#define RS_BLOCK_DWORDS (RS_FILL_BLOCK_ROWS * RS_WORKGROUP_SIZE) +#define RS_RADIX_MASK ((1 << RS_RADIX_LOG2) - 1) +// clang-format on + +// +// +// +layout(local_size_x = RS_WORKGROUP_SIZE) in; + +// +// +// +layout(buffer_reference, std430) buffer buffer_rs_indirect_info_fill +{ + rs_indirect_info_fill info; +}; + +layout(buffer_reference, std430) buffer buffer_rs_dwords +{ + uint32_t extent[]; +}; + +// +// +// +void +main() +{ + // + // Define indirect info bufref for the fill + // + readonly RS_BUFREF_DEFINE(buffer_rs_indirect_info_fill, rs_info, push.devaddr_info); + + RS_SUBGROUP_UNIFORM const rs_indirect_info_fill info = rs_info.info; + + // + // Define dwords bufref + // + // Assumes less than 2^32-1 keys and then extended multiplies it by + // the keyval size. + // + // Assumes push.devaddr_dwords_base is suitably aligned to + // RS_BLOCK_DWORDS -- at a subgroup or transaction size is fine. + // + const uint32_t dwords_idx = + (info.block_offset + gl_WorkGroupID.x) * RS_BLOCK_DWORDS + gl_LocalInvocationID.x; + + u32vec2 dwords_offset; + + umulExtended(dwords_idx, 4, dwords_offset.y, dwords_offset.x); + + writeonly RS_BUFREF_DEFINE_AT_OFFSET_U32VEC2(buffer_rs_dwords, + rs_dwords, + push.devaddr_dwords, + dwords_offset); + + // + // Fills are always aligned to RS_BLOCK_KEYVALS + // + // ((v >= min) && (v < max)) == ((v - min) < (max - min)) + // + const uint32_t row_idx = dwords_idx - info.dword_offset_min; + + [[unroll]] for (uint32_t ii = 0; ii < RS_FILL_BLOCK_ROWS; ii++) + { + if (row_idx + (ii * RS_WORKGROUP_SIZE) < info.dword_offset_max_minus_min) + { + rs_dwords.extent[ii * RS_WORKGROUP_SIZE] = push.dword; + } + } +} + +// +// +// diff --git a/src/amd/vulkan/radix_sort/shaders/histogram.comp b/src/amd/vulkan/radix_sort/shaders/histogram.comp new file mode 100644 index 00000000000..7d554630fe5 --- /dev/null +++ b/src/amd/vulkan/radix_sort/shaders/histogram.comp @@ -0,0 +1,449 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#version 460 + +// +// Produce multiple radix size histograms from the keyvals. +// + +// clang-format off +#extension GL_GOOGLE_include_directive : require +#extension GL_EXT_control_flow_attributes : require +#extension GL_KHR_shader_subgroup_basic : require +// clang-format on + +// +// +// +#include "config.h" + +// +// Optional switches: +// +// #define RS_HISTOGRAM_ENABLE_BITFIELD_EXTRACT +// #define RS_HISTOGRAM_DISABLE_SMEM_HISTOGRAM +// + +// +// Buffer reference macros and push constants +// +#include "bufref.h" +#include "push.h" + +// +// Push constants for histogram shader +// +RS_STRUCT_PUSH_HISTOGRAM(); + +layout(push_constant) uniform block_push +{ + rs_push_histogram push; +}; + +// +// Subgroup uniform support +// +#if defined(RS_HISTOGRAM_SUBGROUP_UNIFORM_DISABLE) && defined(GL_EXT_subgroupuniform_qualifier) +#extension GL_EXT_subgroupuniform_qualifier : required +#define RS_SUBGROUP_UNIFORM subgroupuniformEXT +#else +#define RS_SUBGROUP_UNIFORM +#endif + +// +// Check all switches are defined +// + +// What's the size of the keyval? 
+#ifndef RS_KEYVAL_DWORDS +#error "Undefined: RS_KEYVAL_DWORDS" +#endif + +// +#ifndef RS_HISTOGRAM_BLOCK_ROWS +#error "Undefined: RS_HISTOGRAM_BLOCK_ROWS" +#endif + +// +#ifndef RS_HISTOGRAM_WORKGROUP_SIZE_LOG2 +#error "Undefined: RS_HISTOGRAM_WORKGROUP_SIZE_LOG2" +#endif + +// +#ifndef RS_HISTOGRAM_SUBGROUP_SIZE_LOG2 +#error "Undefined: RS_HISTOGRAM_SUBGROUP_SIZE_LOG2" +#endif + +// +// Local macros +// +// clang-format off +#define RS_WORKGROUP_SIZE (1 << RS_HISTOGRAM_WORKGROUP_SIZE_LOG2) +#define RS_SUBGROUP_SIZE (1 << RS_HISTOGRAM_SUBGROUP_SIZE_LOG2) +#define RS_WORKGROUP_SUBGROUPS (RS_WORKGROUP_SIZE / RS_SUBGROUP_SIZE) +#define RS_BLOCK_KEYVALS (RS_HISTOGRAM_BLOCK_ROWS * RS_WORKGROUP_SIZE) +#define RS_KEYVAL_SIZE (RS_KEYVAL_DWORDS * 4) +#define RS_RADIX_MASK ((1 << RS_RADIX_LOG2) - 1) +// clang-format on + +// +// Keyval type +// +#if (RS_KEYVAL_DWORDS == 1) +#define RS_KEYVAL_TYPE uint32_t +#elif (RS_KEYVAL_DWORDS == 2) +#define RS_KEYVAL_TYPE u32vec2 +#else +#error "Unsupported RS_KEYVAL_DWORDS" +#endif + +// +// Histogram offset depends on number of workgroups. +// +#define RS_HISTOGRAM_BASE(pass_) ((RS_RADIX_SIZE * 4) * pass_) + +#if (RS_WORKGROUP_SUBGROUPS == 1) +#define RS_HISTOGRAM_OFFSET(pass_) (RS_HISTOGRAM_BASE(pass_) + gl_SubgroupInvocationID * 4) +#else +#define RS_HISTOGRAM_OFFSET(pass_) (RS_HISTOGRAM_BASE(pass_) + gl_LocalInvocationID.x * 4) +#endif + +// +// Assumes (RS_RADIX_LOG2 == 8) +// +// Error if this ever changes +// +#if (RS_RADIX_LOG2 != 8) +#error "(RS_RADIX_LOG2 != 8)" +#endif + +// +// Is bitfield extract faster? +// +#ifdef RS_HISTOGRAM_ENABLE_BITFIELD_EXTRACT +//---------------------------------------------------------------------- + +// +// Extract a keyval digit +// +#if (RS_KEYVAL_DWORDS == 1) +#define RS_KV_EXTRACT_DIGIT(kv_, pass_) bitfieldExtract(kv_, pass_ * RS_RADIX_LOG2, RS_RADIX_LOG2) +#else +#define RS_KV_EXTRACT_DIGIT(kv_, pass_) \ + bitfieldExtract(kv_[pass_ / 4], (pass_ & 3) * RS_RADIX_LOG2, RS_RADIX_LOG2) +#endif +//---------------------------------------------------------------------- +#else +//---------------------------------------------------------------------- + +// +// Extract a keyval digit +// +#if (RS_KEYVAL_DWORDS == 1) +#define RS_KV_EXTRACT_DIGIT(kv_, pass_) ((kv_ >> (pass_ * RS_RADIX_LOG2)) & RS_RADIX_MASK) +#else +#define RS_KV_EXTRACT_DIGIT(kv_, pass_) \ + ((kv_[pass_ / 4] >> ((pass_ & 3) * RS_RADIX_LOG2)) & RS_RADIX_MASK) +#endif +//---------------------------------------------------------------------- +#endif + +// +// +// +#ifndef RS_HISTOGRAM_DISABLE_SMEM_HISTOGRAM + +struct rs_histogram_smem +{ + uint32_t histogram[RS_RADIX_SIZE]; +}; + +shared rs_histogram_smem smem; + +#endif + +// +// +// +layout(local_size_x = RS_WORKGROUP_SIZE) in; + +// +// +// +layout(buffer_reference, std430) buffer buffer_rs_kv +{ + RS_KEYVAL_TYPE extent[]; +}; + +layout(buffer_reference, std430) buffer buffer_rs_histograms +{ + uint32_t extent[]; +}; + +// +// Shared memory functions +// +#ifndef RS_HISTOGRAM_DISABLE_SMEM_HISTOGRAM + +// +// NOTE: Must use same access pattern as rs_histogram_zero() +// +void +rs_histogram_zero() +{ + // + // Zero SMEM histogram + // +#if (RS_WORKGROUP_SUBGROUPS == 1) + + const uint32_t smem_offset = gl_SubgroupInvocationID; + + [[unroll]] for (RS_SUBGROUP_UNIFORM uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE) + { + smem.histogram[smem_offset + ii] = 0; + } + +#elif (RS_WORKGROUP_SIZE < RS_RADIX_SIZE) + + const uint32_t smem_offset = gl_LocalInvocationID.x; + + [[unroll]] for (uint32_t ii = 0; ii 
< RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE) + { + smem.histogram[smem_offset + ii] = 0; + } + + const uint32_t smem_idx = smem_offset + ((RS_RADIX_SIZE / RS_WORKGROUP_SIZE) * RS_WORKGROUP_SIZE); + + if (smem_idx < RS_RADIX_SIZE) + { + smem.histogram[smem_idx] = 0; + } + +#elif (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE) + +#if (RS_WORKGROUP_SIZE > RS_RADIX_SIZE) + if (gl_LocalInvocationID.x < RS_RADIX_SIZE) +#endif + { + smem.histogram[gl_LocalInvocationID.x] = 0; + } + +#endif +} + +// +// NOTE: Must use same access pattern as rs_histogram_zero() +// +void +rs_histogram_global_store(restrict buffer_rs_histograms rs_histograms) +{ + // + // Store to GMEM + // +#if (RS_WORKGROUP_SUBGROUPS == 1) + + const uint32_t smem_offset = gl_SubgroupInvocationID; + + [[unroll]] for (RS_SUBGROUP_UNIFORM uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE) + { + const uint32_t count = smem.histogram[smem_offset + ii]; + + atomicAdd(rs_histograms.extent[ii], count); + } + +#elif (RS_WORKGROUP_SIZE < RS_RADIX_SIZE) + + const uint32_t smem_offset = gl_LocalInvocationID.x; + + [[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE) + { + const uint32_t count = smem.histogram[smem_offset + ii]; + + atomicAdd(rs_histograms.extent[ii], count); + } + + const uint32_t smem_idx = smem_offset + ((RS_RADIX_SIZE / RS_WORKGROUP_SIZE) * RS_WORKGROUP_SIZE); + + if (smem_idx < RS_RADIX_SIZE) + { + const uint32_t count = smem.histogram[smem_idx]; + + atomicAdd(rs_histograms.extent[((RS_RADIX_SIZE / RS_WORKGROUP_SIZE) * RS_WORKGROUP_SIZE)], + count); + } + +#elif (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE) + +#if (RS_WORKGROUP_SIZE > RS_RADIX_SIZE) + if (gl_LocalInvocationID.x < RS_RADIX_SIZE) +#endif + { + const uint32_t count = smem.histogram[gl_LocalInvocationID.x]; + + atomicAdd(rs_histograms.extent[0], count); + } + +#endif +} + +#endif + +// +// +// +#ifndef RS_HISTOGRAM_DISABLE_SMEM_HISTOGRAM + +void +rs_histogram_atomic_after_write() +{ +#if (RS_WORKGROUP_SUBGROUPS == 1) + subgroupMemoryBarrierShared(); +#else + barrier(); +#endif +} + +void +rs_histogram_read_after_atomic() +{ +#if (RS_WORKGROUP_SUBGROUPS == 1) + subgroupMemoryBarrierShared(); +#else + barrier(); +#endif +} + +#endif + +// +// +// +void +main() +{ + // + // Which subgroups have work? + // + RS_KEYVAL_TYPE kv[RS_HISTOGRAM_BLOCK_ROWS]; + + // + // Define kv_in bufref + // + // Assumes less than 2^30-1 keys and then extended multiplies it + // by the keyval size. 
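+  //
+  // (For example, with 64-bit keyvals (RS_KEYVAL_SIZE == 8) a keyval index of
+  // 2^29 maps to a byte offset of 2^32, which already overflows 32 bits, so
+  // umulExtended() below returns the offset as an {lsb, msb} pair.)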
+ // + u32vec2 kv_in_offset; + + umulExtended(gl_WorkGroupID.x * RS_BLOCK_KEYVALS + gl_LocalInvocationID.x, + RS_KEYVAL_SIZE, + kv_in_offset.y, // msb + kv_in_offset.x); // lsb + + readonly RS_BUFREF_DEFINE_AT_OFFSET_U32VEC2(buffer_rs_kv, + rs_kv_in, + push.devaddr_keyvals, + kv_in_offset); + + // + // Load keyvals + // + [[unroll]] for (RS_SUBGROUP_UNIFORM uint32_t ii = 0; ii < RS_HISTOGRAM_BLOCK_ROWS; ii++) + { + kv[ii] = rs_kv_in.extent[ii * RS_WORKGROUP_SIZE]; + } + + //////////////////////////////////////////////////////////////////////////// + // + // Accumulate and store histograms for passes + // + //////////////////////////////////////////////////////////////////////////// + + //////////////////////////////////////////////////////////////////////////// + // + // MACRO EXPANSION VARIANT + // + // NOTE: THIS ALSO SERVES AS A MALI R24+ WORKAROUND: EXPLICITLY + // EXPAND THE FOR/LOOP PASSES + // +#ifndef RS_HISTOGRAM_DISABLE_SMEM_HISTOGRAM + +#define RS_HISTOGRAM_PASS(pass_) \ + rs_histogram_zero(); \ + \ + rs_histogram_atomic_after_write(); \ + \ + [[unroll]] for (RS_SUBGROUP_UNIFORM uint32_t jj = 0; jj < RS_HISTOGRAM_BLOCK_ROWS; jj++) \ + { \ + const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[jj], pass_); \ + \ + atomicAdd(smem.histogram[digit], 1); \ + } \ + \ + rs_histogram_read_after_atomic(); \ + \ + { \ + const uint32_t rs_histogram_offset = RS_HISTOGRAM_OFFSET(pass_); \ + \ + RS_BUFREF_DEFINE_AT_OFFSET_UINT32(buffer_rs_histograms, \ + rs_histograms, \ + push.devaddr_histograms, \ + rs_histogram_offset); \ + \ + rs_histogram_global_store(rs_histograms); \ + } \ + \ + if (push.passes == (RS_KEYVAL_SIZE - pass_)) \ + return; + +#else // NO SHARED MEMORY + +#define RS_HISTOGRAM_PASS(pass_) \ + { \ + const uint32_t rs_histogram_base = RS_HISTOGRAM_BASE(pass_); \ + \ + RS_BUFREF_DEFINE_AT_OFFSET_UINT32(buffer_rs_histograms, \ + rs_histograms, \ + push.devaddr_histograms, \ + rs_histogram_base); \ + \ + [[unroll]] for (RS_SUBGROUP_UNIFORM uint32_t jj = 0; jj < RS_HISTOGRAM_BLOCK_ROWS; jj++) \ + { \ + const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[jj], pass_); \ + \ + atomicAdd(rs_histograms.extent[digit], 1); \ + } \ + } \ + \ + if (push.passes == (RS_KEYVAL_SIZE - pass_)) \ + return; + +#endif + +#if (RS_KEYVAL_DWORDS == 1) + + RS_HISTOGRAM_PASS(3) + RS_HISTOGRAM_PASS(2) + RS_HISTOGRAM_PASS(1) + RS_HISTOGRAM_PASS(0) + +#elif (RS_KEYVAL_DWORDS == 2) + + RS_HISTOGRAM_PASS(7) + RS_HISTOGRAM_PASS(6) + RS_HISTOGRAM_PASS(5) + RS_HISTOGRAM_PASS(4) + RS_HISTOGRAM_PASS(3) + RS_HISTOGRAM_PASS(2) + RS_HISTOGRAM_PASS(1) + RS_HISTOGRAM_PASS(0) + +#else +#error "Error: (RS_KEYVAL_DWORDS >= 3) not implemented." +#endif +} + +// +// +// diff --git a/src/amd/vulkan/radix_sort/shaders/init.comp b/src/amd/vulkan/radix_sort/shaders/init.comp new file mode 100644 index 00000000000..1ffd48d79df --- /dev/null +++ b/src/amd/vulkan/radix_sort/shaders/init.comp @@ -0,0 +1,168 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +#version 460 + +// +// Initialize the `rs_indirect_info` struct +// + +// clang-format off +#extension GL_GOOGLE_include_directive : require +#extension GL_EXT_control_flow_attributes : require +// clang-format on + +// +// Load arch/keyval configuration +// +#include "config.h" + +// +// Buffer reference macros and push constants +// +#include "bufref.h" +#include "push.h" + +// +// Subgroup uniform support +// +#if defined(RS_SCATTER_SUBGROUP_UNIFORM_DISABLE) && defined(GL_EXT_subgroupuniform_qualifier) +#extension GL_EXT_subgroupuniform_qualifier : required +#define RS_SUBGROUP_UNIFORM subgroupuniformEXT +#else +#define RS_SUBGROUP_UNIFORM +#endif + +// +// Declare the push constants +// +RS_STRUCT_PUSH_INIT(); + +layout(push_constant) uniform block_push +{ + rs_push_init push; +}; + +// +// The "init" shader configures the fill info structure. +// +RS_STRUCT_INDIRECT_INFO(); + +// +// Local macros +// +// clang-format off +#define RS_FILL_WORKGROUP_SIZE (1 << RS_FILL_WORKGROUP_SIZE_LOG2) +#define RS_SCATTER_WORKGROUP_SIZE (1 << RS_SCATTER_WORKGROUP_SIZE_LOG2) +#define RS_HISTOGRAM_WORKGROUP_SIZE (1 << RS_HISTOGRAM_WORKGROUP_SIZE_LOG2) + +#define RS_FILL_BLOCK_DWORDS (RS_FILL_BLOCK_ROWS * RS_FILL_WORKGROUP_SIZE) +#define RS_SCATTER_BLOCK_KEYVALS (RS_SCATTER_BLOCK_ROWS * RS_SCATTER_WORKGROUP_SIZE) +#define RS_HISTOGRAM_BLOCK_KEYVALS (RS_HISTOGRAM_BLOCK_ROWS * RS_HISTOGRAM_WORKGROUP_SIZE) +// clang-format on + +// +// +// +layout(local_size_x = 1) in; + +// +// +// +layout(buffer_reference, std430) buffer buffer_rs_count +{ + uint32_t count; +}; + +layout(buffer_reference, std430) buffer buffer_rs_indirect_info +{ + rs_indirect_info info; +}; + +// +// Helper macros +// +// RU = Round Up +// RD = Round Down +// +#define RS_COUNT_RU_BLOCKS(count_, block_size_) ((count_ + block_size_ - 1) / (block_size_)) +#define RS_COUNT_RD_BLOCKS(count_, block_size_) ((count_) / (block_size_)) + +// +// +// +void +main() +{ + // + // Load the keyval count + // + readonly RS_BUFREF_DEFINE(buffer_rs_count, rs_count, push.devaddr_count); + + RS_SUBGROUP_UNIFORM const uint32_t count = rs_count.count; + + // + // Define the init struct bufref + // + writeonly RS_BUFREF_DEFINE(buffer_rs_indirect_info, rs_indirect_info, push.devaddr_info); + + // + // Size and set scatter dispatch + // + const uint32_t scatter_ru_blocks = RS_COUNT_RU_BLOCKS(count, RS_SCATTER_BLOCK_KEYVALS); + const uint32_t count_ru_scatter = scatter_ru_blocks * RS_SCATTER_BLOCK_KEYVALS; + + rs_indirect_info.info.dispatch.scatter = u32vec4(scatter_ru_blocks, 1, 1, 0); + + // + // Size and set histogram dispatch + // + const uint32_t histo_ru_blocks = RS_COUNT_RU_BLOCKS(count_ru_scatter, RS_HISTOGRAM_BLOCK_KEYVALS); + const uint32_t count_ru_histo = histo_ru_blocks * RS_HISTOGRAM_BLOCK_KEYVALS; + + rs_indirect_info.info.dispatch.histogram = u32vec4(histo_ru_blocks, 1, 1, 0); + + // + // Size and set pad fill and dispatch + // + const uint32_t count_dwords = count * RS_KEYVAL_DWORDS; + const uint32_t pad_rd_blocks = RS_COUNT_RD_BLOCKS(count_dwords, RS_FILL_BLOCK_DWORDS); + const uint32_t count_rd_pad = pad_rd_blocks * RS_FILL_BLOCK_DWORDS; + const uint32_t count_ru_histo_dwords = count_ru_histo * RS_KEYVAL_DWORDS; + const uint32_t pad_dwords = count_ru_histo_dwords - count_rd_pad; + const uint32_t pad_ru_blocks = RS_COUNT_RU_BLOCKS(pad_dwords, RS_FILL_BLOCK_DWORDS); + + rs_indirect_info_fill pad; + + pad.block_offset = pad_rd_blocks; + pad.dword_offset_min = count_dwords; + pad.dword_offset_max_minus_min = count_ru_histo_dwords - count_dwords; 
+ + rs_indirect_info.info.pad = pad; + rs_indirect_info.info.dispatch.pad = u32vec4(pad_ru_blocks, 1, 1, 0); + + // + // Size and set zero fill and dispatch + // + // NOTE(allanmac): We could zero the histogram passes on the host + // since the number of passes is known ahead of time but since the + // 256-dword partitions directly follow the 256-dword histograms, we + // can dispatch just one FILL. + // + rs_indirect_info_fill zero; + + zero.block_offset = 0; + zero.dword_offset_min = 0; + zero.dword_offset_max_minus_min = (push.passes + scatter_ru_blocks - 1) * RS_RADIX_SIZE; + + const uint32_t zero_ru_blocks = + RS_COUNT_RU_BLOCKS(zero.dword_offset_max_minus_min, RS_FILL_BLOCK_DWORDS); + + rs_indirect_info.info.zero = zero; + rs_indirect_info.info.dispatch.zero = u32vec4(zero_ru_blocks, 1, 1, 0); +} + +// +// +// diff --git a/src/amd/vulkan/radix_sort/shaders/meson.build b/src/amd/vulkan/radix_sort/shaders/meson.build new file mode 100644 index 00000000000..e3506fb6654 --- /dev/null +++ b/src/amd/vulkan/radix_sort/shaders/meson.build @@ -0,0 +1,51 @@ +# Copyright © 2022 Konstantin Seurer + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +radix_sort_shaders = [ + 'init.comp', + 'fill.comp', + 'histogram.comp', + 'prefix.comp', + 'scatter_0_even.comp', + 'scatter_0_odd.comp', + 'scatter_1_even.comp', + 'scatter_1_odd.comp' +] + +shader_include_dir = meson.source_root() + '/src/amd/vulkan/radix_sort/targets/u64' + +shader_include_files = files( + 'bufref.h', + 'prefix_limits.h', + 'prefix.h', + 'push.h', + 'scatter.glsl', + meson.source_root() + '/src/amd/vulkan/radix_sort/targets/u64/config.h' +) + +radix_sort_spv = [] +foreach s : radix_sort_shaders + radix_sort_spv += custom_target( + s + '.spv.h', + input : s, + output : s + '.spv.h', + command : [prog_glslang, '-V', '-I' + shader_include_dir, '--target-env', 'spirv1.3', '-x', '-o', '@OUTPUT@', '@INPUT@'], + depend_files: shader_include_files) +endforeach diff --git a/src/amd/vulkan/radix_sort/shaders/prefix.comp b/src/amd/vulkan/radix_sort/shaders/prefix.comp new file mode 100644 index 00000000000..aae88869a6e --- /dev/null +++ b/src/amd/vulkan/radix_sort/shaders/prefix.comp @@ -0,0 +1,194 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#version 460 + +// +// Prefix sum the coarse histograms. 
+// + +// clang-format off +#extension GL_GOOGLE_include_directive : require +#extension GL_EXT_control_flow_attributes : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_shader_subgroup_ballot : require +// clang-format on + +// +// +// +#include "config.h" + +// +// Buffer reference macros and push constants +// +#include "bufref.h" +#include "push.h" + +// +// +// +RS_STRUCT_PUSH_PREFIX(); + +layout(push_constant) uniform block_push +{ + rs_push_prefix push; +}; + +// +// Subgroup uniform support +// +#if defined(RS_HISTOGRAM_SUBGROUP_UNIFORM_DISABLE) && defined(GL_EXT_subgroupuniform_qualifier) +#extension GL_EXT_subgroupuniform_qualifier : required +#define RS_SUBGROUP_UNIFORM subgroupuniformEXT +#else +#define RS_SUBGROUP_UNIFORM +#endif + +// +// Check all switches are defined +// +// +#ifndef RS_PREFIX_SUBGROUP_SIZE_LOG2 +#error "Undefined: RS_PREFIX_SUBGROUP_SIZE_LOG2" +#endif + +// +#ifndef RS_PREFIX_WORKGROUP_SIZE_LOG2 +#error "Undefined: RS_PREFIX_WORKGROUP_SIZE_LOG2" +#endif + +// +// Local macros +// +// clang-format off +#define RS_KEYVAL_SIZE (RS_KEYVAL_DWORDS * 4) +#define RS_WORKGROUP_SIZE (1 << RS_PREFIX_WORKGROUP_SIZE_LOG2) +#define RS_SUBGROUP_SIZE (1 << RS_PREFIX_SUBGROUP_SIZE_LOG2) +#define RS_WORKGROUP_SUBGROUPS (RS_WORKGROUP_SIZE / RS_SUBGROUP_SIZE) +// clang-format on + +// +// There is no purpose in having a workgroup size larger than the +// radix size. +// +#if (RS_WORKGROUP_SIZE > RS_RADIX_SIZE) +#error "Error: (RS_WORKGROUP_SIZE > RS_RADIX_SIZE)" +#endif + +// +// +// +layout(local_size_x = RS_WORKGROUP_SIZE) in; + +// +// Histogram buffer reference +// +layout(buffer_reference, std430) buffer buffer_rs_histograms +{ + uint32_t extent[]; +}; + +// +// Load prefix limits before loading function +// +#include "prefix_limits.h" + +// +// If multi-subgroup then define shared memory +// +#if (RS_WORKGROUP_SUBGROUPS > 1) + +//---------------------------------------- +shared uint32_t smem_sweep0[RS_SWEEP_0_SIZE]; + +#define RS_PREFIX_SWEEP0(idx_) smem_sweep0[idx_] +//---------------------------------------- + +#if (RS_SWEEP_1_SIZE > 0) +//---------------------------------------- +shared uint32_t smem_sweep1[RS_SWEEP_1_SIZE]; + +#define RS_PREFIX_SWEEP1(idx_) smem_sweep1[idx_] +//---------------------------------------- +#endif + +#if (RS_SWEEP_2_SIZE > 0) +//---------------------------------------- +shared uint32_t smem_sweep2[RS_SWEEP_2_SIZE]; + +#define RS_PREFIX_SWEEP2(idx_) smem_sweep2[idx_] +//---------------------------------------- +#endif + +#endif + +// +// Define function arguments +// +#define RS_PREFIX_ARGS buffer_rs_histograms rs_histograms + +// +// Define load/store functions +// +// clang-format off +#define RS_PREFIX_LOAD(idx_) rs_histograms.extent[idx_] +#define RS_PREFIX_STORE(idx_) rs_histograms.extent[idx_] +// clang-format on + +// +// Load prefix function +// +#include "prefix.h" + +// +// Exclusive prefix of uint32_t[256] +// +void +main() +{ + // + // Define buffer reference to read histograms + // +#if (RS_WORKGROUP_SUBGROUPS == 1) + // + // Define histograms bufref for single subgroup + // + // NOTE(allanmac): The histogram buffer reference could be adjusted + // on the host to save a couple instructions at the cost of added + // complexity. 
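+  //
+  // (For example, with 64-bit keyvals RS_KEYVAL_SIZE is 8, so workgroup 0
+  // scans the most significant digit's histogram at dword base
+  // 7 * RS_RADIX_SIZE = 1792, workgroup 1 the next digit at 1536, and so on.)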
+ // + RS_SUBGROUP_UNIFORM + const uint32_t histograms_base = ((RS_KEYVAL_SIZE - 1 - gl_WorkGroupID.x) * RS_RADIX_SIZE); + const uint32_t histograms_offset = (histograms_base + gl_SubgroupInvocationID) * 4; + + RS_BUFREF_DEFINE_AT_OFFSET_UINT32(buffer_rs_histograms, + rs_histograms, + push.devaddr_histograms, + histograms_offset); + +#else + // + // Define histograms bufref for workgroup + // + RS_SUBGROUP_UNIFORM + const uint32_t histograms_base = ((RS_KEYVAL_SIZE - 1 - gl_WorkGroupID.x) * RS_RADIX_SIZE); + const uint32_t histograms_offset = (histograms_base + gl_LocalInvocationID.x) * 4; + + RS_BUFREF_DEFINE_AT_OFFSET_UINT32(buffer_rs_histograms, + rs_histograms, + push.devaddr_histograms, + histograms_offset); + +#endif + + // + // Compute exclusive prefix of uint32_t[256] + // + rs_prefix(rs_histograms); +} + +// +// +// diff --git a/src/amd/vulkan/radix_sort/shaders/prefix.h b/src/amd/vulkan/radix_sort/shaders/prefix.h new file mode 100644 index 00000000000..f9d470bb3f5 --- /dev/null +++ b/src/amd/vulkan/radix_sort/shaders/prefix.h @@ -0,0 +1,353 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PREFIX_H_ +#define SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PREFIX_H_ + +// +// Requires several defines +// +#ifndef RS_PREFIX_LIMITS +#error "Error: \"prefix_limits.h\" not loaded" +#endif + +#ifndef RS_PREFIX_ARGS +#error "Error: RS_PREFIX_ARGS undefined" +#endif + +#ifndef RS_PREFIX_LOAD +#error "Error: RS_PREFIX_LOAD undefined" +#endif + +#ifndef RS_PREFIX_STORE +#error "Error: RS_PREFIX_STORE undefined" +#endif + +#ifndef RS_SUBGROUP_SIZE +#error "Error: RS_SUBGROUP_SIZE undefined" +#endif + +#ifndef RS_WORKGROUP_SIZE +#error "Error: RS_WORKGROUP_SIZE undefined" +#endif + +#ifndef RS_WORKGROUP_SUBGROUPS +#error "Error: RS_WORKGROUP_SUBGROUPS undefined" +#endif + +// +// Optional switches: +// +// * Disable holding original inclusively scanned histogram values in registers. +// +// #define RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS +// + +// +// Compute exclusive prefix of uint32_t[256] +// +void +rs_prefix(RS_PREFIX_ARGS) +{ +#if (RS_WORKGROUP_SUBGROUPS == 1) + // + // Workgroup is a single subgroup so no shared memory is required. + // + + // + // Exclusive scan-add the histogram + // + const uint32_t h0 = RS_PREFIX_LOAD(0); + const uint32_t h0_inc = subgroupInclusiveAdd(h0); + RS_SUBGROUP_UNIFORM uint32_t h_last = subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1); + + RS_PREFIX_STORE(0) = h0_inc - h0; // exclusive + + // + // Each iteration is dependent on the previous so no unrolling. The + // compiler is free to hoist the loads upward though. + // + for (RS_SUBGROUP_UNIFORM uint32_t ii = RS_SUBGROUP_SIZE; // + ii < RS_RADIX_SIZE; + ii += RS_SUBGROUP_SIZE) + { + const uint32_t h = RS_PREFIX_LOAD(ii); + const uint32_t h_inc = subgroupInclusiveAdd(h) + h_last; + h_last = subgroupBroadcast(h_inc, RS_SUBGROUP_SIZE - 1); + + RS_PREFIX_STORE(ii) = h_inc - h; // exclusive + } + +#else + // + // Workgroup is multiple subgroups and uses shared memory to store + // the scan's intermediate results. + // + // Assumes a power-of-two subgroup, workgroup and radix size. + // + // Downsweep: Repeatedly scan reductions until they fit in a single + // subgroup. + // + // Upsweep: Then uniformly apply reductions to each subgroup. 
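+  //
+  // As a small worked example of the scan primitive used in each sweep:
+  // counts { 3, 1, 4, 1 } have the inclusive scan { 3, 4, 8, 9 }, and
+  // subtracting each original count yields the exclusive prefix { 0, 3, 4, 8 }.
+  //
+  // The table below lists the sweep array sizes per subgroup size: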
+ // + // + // Subgroup Size | 4 | 8 | 16 | 32 | 64 | + // --------------+----+----+----+----+----+ + // Sweep 0 | 64 | 32 | 16 | 8 | 4 | sweep_0[] + // Sweep 1 | 16 | 4 | - | - | - | sweep_1[] + // Sweep 2 | 4 | - | - | - | - | sweep_2[] + // --------------+----+----+----+----+----+ + // Total dwords | 84 | 36 | 16 | 8 | 4 | + // --------------+----+----+----+----+----+ + // +#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS + uint32_t h_exc[RS_H_COMPONENTS]; +#endif + + // + // Downsweep 0 + // + [[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++) + { + const uint32_t h = RS_PREFIX_LOAD(ii * RS_WORKGROUP_SIZE); + + const uint32_t h_inc = subgroupInclusiveAdd(h); + + const uint32_t smem_idx = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID; + + RS_PREFIX_SWEEP0(smem_idx) = subgroupBroadcast(h_inc, RS_SUBGROUP_SIZE - 1); + + // +#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS + h_exc[ii] = h_inc - h; +#else + RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) = h_inc - h; +#endif + } + + barrier(); + + // + // Skip generalizing these sweeps for all possible subgroups -- just + // write them directly. + // +#if ((RS_SUBGROUP_SIZE == 64) || (RS_SUBGROUP_SIZE == 32) || (RS_SUBGROUP_SIZE == 16)) + + ////////////////////////////////////////////////////////////////////// + // + // Scan 0 + // +#if (RS_SWEEP_0_SIZE != RS_SUBGROUP_SIZE) + if (gl_LocalInvocationID.x < RS_SWEEP_0_SIZE) // subgroup has inactive invocations +#endif + { + const uint32_t h0_red = RS_PREFIX_SWEEP0(gl_LocalInvocationID.x); + const uint32_t h0_inc = subgroupInclusiveAdd(h0_red); + + RS_PREFIX_SWEEP0(gl_LocalInvocationID.x) = h0_inc - h0_red; + } + +#elif (RS_SUBGROUP_SIZE == 8) + +#if (RS_SWEEP_0_SIZE < RS_WORKGROUP_SIZE) + + ////////////////////////////////////////////////////////////////////// + // + // Scan 0 and Downsweep 1 + // + if (gl_LocalInvocationID.x < RS_SWEEP_0_SIZE) // 32 invocations + { + const uint32_t h0_red = RS_PREFIX_SWEEP0(gl_LocalInvocationID.x); + const uint32_t h0_inc = subgroupInclusiveAdd(h0_red); + + RS_PREFIX_SWEEP0(gl_LocalInvocationID.x) = h0_inc - h0_red; + RS_PREFIX_SWEEP1(gl_SubgroupID) = subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1); + } + +#else + + ////////////////////////////////////////////////////////////////////// + // + // Scan 0 and Downsweep 1 + // + [[unroll]] for (uint32_t ii = 0; ii < RS_S0_PASSES; ii++) // 32 invocations + { + const uint32_t idx0 = (ii * RS_WORKGROUP_SIZE) + gl_LocalInvocationID.x; + const uint32_t idx1 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID; + + const uint32_t h0_red = RS_PREFIX_SWEEP0(idx0); + const uint32_t h0_inc = subgroupInclusiveAdd(h0_red); + + RS_PREFIX_SWEEP0(idx0) = h0_inc - h0_red; + RS_PREFIX_SWEEP1(idx1) = subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1); + } + +#endif + + barrier(); + + // + // Scan 1 + // + if (gl_LocalInvocationID.x < RS_SWEEP_1_SIZE) // 4 invocations + { + const uint32_t h1_red = RS_PREFIX_SWEEP1(gl_LocalInvocationID.x); + const uint32_t h1_inc = subgroupInclusiveAdd(h1_red); + + RS_PREFIX_SWEEP1(gl_LocalInvocationID.x) = h1_inc - h1_red; + } + +#elif (RS_SUBGROUP_SIZE == 4) + + ////////////////////////////////////////////////////////////////////// + // + // Scan 0 and Downsweep 1 + // +#if (RS_SWEEP_0_SIZE < RS_WORKGROUP_SIZE) + + if (gl_LocalInvocationID.x < RS_SWEEP_0_SIZE) // 64 invocations + { + const uint32_t h0_red = RS_PREFIX_SWEEP0(gl_LocalInvocationID.x); + const uint32_t h0_inc = subgroupInclusiveAdd(h0_red); + + RS_PREFIX_SWEEP0(gl_LocalInvocationID.x) = h0_inc - h0_red; + RS_PREFIX_SWEEP1(gl_SubgroupID) = 
subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1); + } + +#else + + [[unroll]] for (uint32_t ii = 0; ii < RS_S0_PASSES; ii++) // 64 invocations + { + const uint32_t idx0 = (ii * RS_WORKGROUP_SIZE) + gl_LocalInvocationID.x; + const uint32_t idx1 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID; + + const uint32_t h0_red = RS_PREFIX_SWEEP0(idx0); + const uint32_t h0_inc = subgroupInclusiveAdd(h0_red); + + RS_PREFIX_SWEEP0(idx0) = h0_inc - h0_red; + RS_PREFIX_SWEEP1(idx1) = subgroupBroadcast(h0_inc, RS_SUBGROUP_SIZE - 1); + } +#endif + + barrier(); + + // + // Scan 1 and Downsweep 2 + // +#if (RS_SWEEP_1_SIZE < RS_WORKGROUP_SIZE) + if (gl_LocalInvocationID.x < RS_SWEEP_1_SIZE) // 16 invocations + { + const uint32_t h1_red = RS_PREFIX_SWEEP1(gl_LocalInvocationID.x); + const uint32_t h1_inc = subgroupInclusiveAdd(h1_red); + + RS_PREFIX_SWEEP1(gl_LocalInvocationID.x) = h1_inc - h1_red; + RS_PREFIX_SWEEP2(gl_SubgroupID) = subgroupBroadcast(h1_inc, RS_SUBGROUP_SIZE - 1); + } + +#else + + [[unroll]] for (uint32_t ii = 0; ii < RS_S1_PASSES; ii++) // 16 invocations + { + const uint32_t idx1 = (ii * RS_WORKGROUP_SIZE) + gl_LocalInvocationID.x; + const uint32_t idx2 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID; + + const uint32_t h1_red = RS_PREFIX_SWEEP1(idx1); + const uint32_t h1_inc = subgroupInclusiveAdd(h1_red); + + RS_PREFIX_SWEEP1(idx1) = h1_inc - h1_red; + RS_PREFIX_SWEEP2(idx2) = subgroupBroadcast(h1_inc, RS_SUBGROUP_SIZE - 1); + } + +#endif + + barrier(); + + // + // Scan 2 + // + // 4 invocations + // + if (gl_LocalInvocationID.x < RS_SWEEP_2_SIZE) + { + const uint32_t h2_red = RS_PREFIX_SWEEP2(gl_LocalInvocationID.x); + const uint32_t h2_inc = subgroupInclusiveAdd(h2_red); + + RS_PREFIX_SWEEP2(gl_LocalInvocationID.x) = h2_inc - h2_red; + } + +#else +#error "Error: Unsupported subgroup size" +#endif + + barrier(); + + ////////////////////////////////////////////////////////////////////// + // + // Final upsweep 0 + // +#if ((RS_SUBGROUP_SIZE == 64) || (RS_SUBGROUP_SIZE == 32) || (RS_SUBGROUP_SIZE == 16)) + + [[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++) + { + const uint32_t idx0 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID; + + // clang format issue +#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS + RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) = h_exc[ii] + RS_PREFIX_SWEEP0(idx0); +#else + const uint32_t h_exc = RS_PREFIX_LOAD(ii * RS_WORKGROUP_SIZE); + + RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) = h_exc + RS_PREFIX_SWEEP0(idx0); +#endif + } + +#elif (RS_SUBGROUP_SIZE == 8) + + [[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++) + { + const uint32_t idx0 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID; + const uint32_t idx1 = idx0 / RS_SUBGROUP_SIZE; + +#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS + RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) = + h_exc[ii] + RS_PREFIX_SWEEP0(idx0) + RS_PREFIX_SWEEP1(idx1); +#else + const uint32_t h_exc = RS_PREFIX_LOAD(ii * RS_WORKGROUP_SIZE); + + RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) = + h_exc + RS_PREFIX_SWEEP0(idx0) + RS_PREFIX_SWEEP1(idx1); +#endif + } + +#elif (RS_SUBGROUP_SIZE == 4) + + [[unroll]] for (uint32_t ii = 0; ii < RS_H_COMPONENTS; ii++) + { + const uint32_t idx0 = (ii * RS_WORKGROUP_SUBGROUPS) + gl_SubgroupID; + const uint32_t idx1 = idx0 / RS_SUBGROUP_SIZE; + const uint32_t idx2 = idx1 / RS_SUBGROUP_SIZE; + +#ifndef RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS + RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) = + h_exc[ii] + (RS_PREFIX_SWEEP0(idx0) + RS_PREFIX_SWEEP1(idx1) + RS_PREFIX_SWEEP2(idx2)); +#else + const uint32_t 
h_exc = RS_PREFIX_LOAD(ii * RS_WORKGROUP_SIZE); + + RS_PREFIX_STORE(ii * RS_WORKGROUP_SIZE) = + h_exc + (RS_PREFIX_SWEEP0(idx0) + RS_PREFIX_SWEEP1(idx1) + RS_PREFIX_SWEEP2(idx2)); +#endif + } + +#else +#error "Error: Unsupported subgroup size" +#endif + +#endif +} + +// +// +// + +#endif // SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PREFIX_H_ diff --git a/src/amd/vulkan/radix_sort/shaders/prefix_limits.h b/src/amd/vulkan/radix_sort/shaders/prefix_limits.h new file mode 100644 index 00000000000..a98e554ad4a --- /dev/null +++ b/src/amd/vulkan/radix_sort/shaders/prefix_limits.h @@ -0,0 +1,48 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PREFIX_LIMITS_H_ +#define SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PREFIX_LIMITS_H_ + +// +// Define various prefix limits +// +#define RS_PREFIX_LIMITS + +// +// Multi-subgroup prefix requires shared memory. +// +#if (RS_WORKGROUP_SUBGROUPS > 1) + +// clang-format off +#define RS_H_COMPONENTS (RS_RADIX_SIZE / RS_WORKGROUP_SIZE) + +#define RS_SWEEP_0_SIZE (RS_RADIX_SIZE / RS_SUBGROUP_SIZE) +#define RS_SWEEP_1_SIZE (RS_SWEEP_0_SIZE / RS_SUBGROUP_SIZE) +#define RS_SWEEP_2_SIZE (RS_SWEEP_1_SIZE / RS_SUBGROUP_SIZE) + +#define RS_SWEEP_SIZE (RS_SWEEP_0_SIZE + RS_SWEEP_1_SIZE + RS_SWEEP_2_SIZE) + +#define RS_S0_PASSES (RS_SWEEP_0_SIZE / RS_WORKGROUP_SIZE) +#define RS_S1_PASSES (RS_SWEEP_1_SIZE / RS_WORKGROUP_SIZE) + +#define RS_SWEEP_0_OFFSET 0 +#define RS_SWEEP_1_OFFSET (RS_SWEEP_0_OFFSET + RS_SWEEP_0_SIZE) +#define RS_SWEEP_2_OFFSET (RS_SWEEP_1_OFFSET + RS_SWEEP_1_SIZE) +// clang-format on + +// +// Single subgroup prefix doesn't use shared memory. +// +#else + +#define RS_SWEEP_SIZE 0 + +#endif + +// +// +// + +#endif // SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PREFIX_LIMITS_H_ diff --git a/src/amd/vulkan/radix_sort/shaders/push.h b/src/amd/vulkan/radix_sort/shaders/push.h new file mode 100644 index 00000000000..76c0806861d --- /dev/null +++ b/src/amd/vulkan/radix_sort/shaders/push.h @@ -0,0 +1,263 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PUSH_H_ +#define SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PUSH_H_ + +// +// There is a limit to the maximum number of keyvals that can be sorted because +// the top 2 bits in the atomic lookback counters are used as tag bits. +// +#define RS_MAX_KEYVALS ((1 << 30) - 1) + +// +// Right now, the entire implementation is very much dependent on an 8-bit radix +// size. Most of the shaders attempt to honor this defined size but there are +// still a number of places where 256 is assumed. +// +#define RS_RADIX_LOG2 8 +#define RS_RADIX_SIZE (1 << RS_RADIX_LOG2) + +// +// LOOKBACK STATUS FLAGS +// +// The decoupled lookback status flags are stored in the two +// high bits of the count: +// +// 0 31 +// | REDUCTION OR PREFIX COUNT | STATUS | +// +---------------------------+--------+ +// | 30 | 2 | +// +// This limits the keyval extent size to (2^30-1). 
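As an illustration of the 30/2-bit packing just described, here is a minimal host-side C sketch; the helper names are invented for this example, and the two-bit tag values follow the even-pass encoding listed below (1 = reduction available, 2 = prefix available).

#include <assert.h>
#include <stdint.h>

#define STATUS_SHIFT 30
#define COUNT_MASK   0x3FFFFFFFu  /* low 30 bits: reduction or prefix count */
#define STATUS_MASK  0xC0000000u  /* high 2 bits: lookback status tag       */

/* Hypothetical helpers illustrating the packing only. */
static uint32_t pack(uint32_t count, uint32_t status)
{
  return (count & COUNT_MASK) | (status << STATUS_SHIFT);
}

static uint32_t unpack_count(uint32_t word)  { return word & COUNT_MASK; }
static uint32_t unpack_status(uint32_t word) { return word >> STATUS_SHIFT; }

int main(void)
{
  /* Even pass encoding: 1 = reduction available, 2 = prefix available. */
  uint32_t word = pack(12345, 1u);

  /* Adding (prefix | (1 << 30)) both adds the exclusive prefix to the
     count and upgrades REDUCTION to PREFIX -- a single atomic add on
     the device. */
  word += pack(1000, 1u);

  assert(unpack_status(word) == 2u);
  assert(unpack_count(word) == 13345u);
  return 0;
}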
+// +// Valid status flags are: +// +// EVEN PASS ODD PASS +// ----------------------- ----------------------- +// 0 : invalid 0 : prefix available +// 1 : reduction available 1 : --- +// 2 : prefix available 2 : invalid +// 3 : --- 3 : reduction available +// +// Atomically adding +1 to a "reduction available" status results in a "prefix +// available" status. +// +// clang-format off +#define RS_PARTITION_STATUS_EVEN_INVALID (0u) +#define RS_PARTITION_STATUS_EVEN_REDUCTION (1u) +#define RS_PARTITION_STATUS_EVEN_PREFIX (2u) + +#define RS_PARTITION_STATUS_ODD_INVALID (2u) +#define RS_PARTITION_STATUS_ODD_REDUCTION (3u) +#define RS_PARTITION_STATUS_ODD_PREFIX (0u) +// clang-format on + +// +// Arguments to indirectly launched shaders. +// +// struct rs_indirect_info_dispatch +// { +// u32vec4 pad; +// u32vec4 zero; +// u32vec4 histogram; +// u32vec4 scatter; +// }; +// +// struct rs_indirect_info_fill +// { +// uint32_t block_offset; +// uint32_t dword_offset_min; +// uint32_t dword_offset_max_minus_min; +// uint32_t reserved; // padding for 16 bytes +// }; +// +// struct rs_indirect_info +// { +// rs_indirect_info_fill pad; +// rs_indirect_info_fill zero; +// rs_indirect_info_dispatch dispatch; +// }; +// +#define RS_STRUCT_INDIRECT_INFO_DISPATCH() \ + struct rs_indirect_info_dispatch \ + { \ + RS_STRUCT_MEMBER_STRUCT(u32vec4, pad) \ + RS_STRUCT_MEMBER_STRUCT(u32vec4, zero) \ + RS_STRUCT_MEMBER_STRUCT(u32vec4, histogram) \ + RS_STRUCT_MEMBER_STRUCT(u32vec4, scatter) \ + } + +#define RS_STRUCT_INDIRECT_INFO_FILL() \ + struct rs_indirect_info_fill \ + { \ + RS_STRUCT_MEMBER(uint32_t, block_offset) \ + RS_STRUCT_MEMBER(uint32_t, dword_offset_min) \ + RS_STRUCT_MEMBER(uint32_t, dword_offset_max_minus_min) \ + RS_STRUCT_MEMBER(uint32_t, reserved) \ + } + +#define RS_STRUCT_INDIRECT_INFO() \ + RS_STRUCT_INDIRECT_INFO_DISPATCH(); \ + RS_STRUCT_INDIRECT_INFO_FILL(); \ + struct rs_indirect_info \ + { \ + RS_STRUCT_MEMBER_STRUCT(rs_indirect_info_fill, pad) \ + RS_STRUCT_MEMBER_STRUCT(rs_indirect_info_fill, zero) \ + RS_STRUCT_MEMBER_STRUCT(rs_indirect_info_dispatch, dispatch) \ + } + +// +// Define the push constant structures shared by the host and device. 
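The RS_STRUCT_* macros in this header let one member list expand under both the GLSL and the C/C++ compiler, with the RS_STRUCT_MEMBER* macros supplying the per-language syntax (their two expansions appear near the end of the file). Below is a reduced, C-only sketch of that pattern; the struct and member names are made up for illustration.

#include <stdint.h>
#include <stdio.h>

/* Per-language member syntax -- this is the C expansion; the GLSL side
   would define the same macro with GLSL types/qualifiers instead. */
#define DEMO_STRUCT_MEMBER(type_, name_) type_ name_;

/* Language-neutral struct description, shared by host and device. */
#define DEMO_STRUCT_PUSH()                      \
  struct demo_push                              \
  {                                             \
    DEMO_STRUCT_MEMBER(uint64_t, devaddr_data)  \
    DEMO_STRUCT_MEMBER(uint32_t, dword)         \
  }

DEMO_STRUCT_PUSH();  /* expands to the C struct definition */

int main(void)
{
  printf("sizeof(struct demo_push) = %zu\n", sizeof(struct demo_push));
  return 0;
}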
+// +// INIT +// ---- +// struct rs_push_init +// { +// uint64_t devaddr_count; // address of count buffer +// uint64_t devaddr_indirect; // address of indirect info buffer +// }; +// +// FILL +// ---- +// struct rs_push_fill +// { +// uint64_t devaddr_info; // address of indirect info for fill shader +// uint64_t devaddr_dwords; // address of dwords extent +// uint32_t dword; // dword value used to fill the dwords extent +// }; +// +// HISTOGRAM +// --------- +// struct rs_push_histogram +// { +// uint64_t devaddr_histograms; // address of histograms extent +// uint64_t devaddr_keyvals; // address of keyvals extent +// uint32_t passes; // number of passes +// }; +// +// PREFIX +// ------ +// struct rs_push_prefix +// { +// uint64_t devaddr_histograms; // address of histograms extent +// }; +// +// SCATTER +// ------- +// struct rs_push_scatter +// { +// uint64_t devaddr_keyvals_in; // address of input keyvals +// uint64_t devaddr_keyvals_out; // address of output keyvals +// uint64_t devaddr_partitions // address of partitions +// uint64_t devaddr_histogram; // address of pass histogram +// uint32_t pass_offset; // keyval pass offset +// }; +// +#define RS_STRUCT_PUSH_INIT() \ + struct rs_push_init \ + { \ + RS_STRUCT_MEMBER(RS_DEVADDR, devaddr_info) \ + RS_STRUCT_MEMBER(RS_DEVADDR, devaddr_count) \ + RS_STRUCT_MEMBER(uint32_t, passes) \ + } + +#define RS_STRUCT_PUSH_FILL() \ + struct rs_push_fill \ + { \ + RS_STRUCT_MEMBER(RS_DEVADDR, devaddr_info) \ + RS_STRUCT_MEMBER(RS_DEVADDR, devaddr_dwords) \ + RS_STRUCT_MEMBER(uint32_t, dword) \ + } + +#define RS_STRUCT_PUSH_HISTOGRAM() \ + struct rs_push_histogram \ + { \ + RS_STRUCT_MEMBER(RS_DEVADDR, devaddr_histograms) \ + RS_STRUCT_MEMBER(RS_DEVADDR, devaddr_keyvals) \ + RS_STRUCT_MEMBER(uint32_t, passes) \ + } + +#define RS_STRUCT_PUSH_PREFIX() \ + struct rs_push_prefix \ + { \ + RS_STRUCT_MEMBER(RS_DEVADDR, devaddr_histograms) \ + } + +#define RS_STRUCT_PUSH_SCATTER() \ + struct rs_push_scatter \ + { \ + RS_STRUCT_MEMBER(RS_DEVADDR, devaddr_keyvals_even) \ + RS_STRUCT_MEMBER(RS_DEVADDR, devaddr_keyvals_odd) \ + RS_STRUCT_MEMBER(RS_DEVADDR, devaddr_partitions) \ + RS_STRUCT_MEMBER(RS_DEVADDR, devaddr_histograms) \ + RS_STRUCT_MEMBER(uint32_t, pass_offset) \ + } + +//////////////////////////////////////////////////////////////////// +// +// GLSL +// +#ifdef VULKAN // defined by GLSL/VK compiler + +// clang-format off +#define RS_STRUCT_MEMBER(type_, name_) type_ name_; +#define RS_STRUCT_MEMBER_FARRAY(type_, len_, name_) type_ name_[len_]; +#define RS_STRUCT_MEMBER_STRUCT(type_, name_) type_ name_; +// clang-format on + +//////////////////////////////////////////////////////////////////// +// +// C/C++ +// +#else + +#ifdef __cplusplus +extern "C" { +#endif + +// +// +// + +#include + +struct u32vec4 +{ + uint32_t x; + uint32_t y; + uint32_t z; + uint32_t w; +}; + +// clang-format off +#define RS_DEVADDR uint64_t +#define RS_STRUCT_MEMBER(type_, name_) type_ name_; +#define RS_STRUCT_MEMBER_FARRAY(type_, len_, name_) type_ name_[len_]; +#define RS_STRUCT_MEMBER_STRUCT(type_, name_) struct type_ name_; +// clang-format on + +RS_STRUCT_PUSH_INIT(); +RS_STRUCT_PUSH_FILL(); +RS_STRUCT_PUSH_HISTOGRAM(); +RS_STRUCT_PUSH_PREFIX(); +RS_STRUCT_PUSH_SCATTER(); + +RS_STRUCT_INDIRECT_INFO(); + +// +// +// + +#ifdef __cplusplus +} +#endif + +#endif + +// +// +// + +#endif // SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_SHADERS_PUSH_H_ diff --git a/src/amd/vulkan/radix_sort/shaders/scatter.glsl b/src/amd/vulkan/radix_sort/shaders/scatter.glsl new file 
mode 100644 index 00000000000..b57d9e80850 --- /dev/null +++ b/src/amd/vulkan/radix_sort/shaders/scatter.glsl @@ -0,0 +1,1706 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// #pragma use_vulkan_memory_model // results in spirv-remap validation error + +// +// Each "pass" scatters the keyvals to their new destinations. +// +// clang-format off +#extension GL_EXT_control_flow_attributes : require +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_memory_scope_semantics : require +#extension GL_KHR_shader_subgroup_ballot : require +// clang-format on + +// +// Load arch/keyval configuration +// +#include "config.h" + +// +// Optional switches: +// +// #define RS_SCATTER_DISABLE_REORDER +// #define RS_SCATTER_ENABLE_BITFIELD_EXTRACT +// #define RS_SCATTER_ENABLE_NV_MATCH +// #define RS_SCATTER_ENABLE_BROADCAST_MATCH +// #define RS_SCATTER_DISABLE_COMPONENTS_IN_REGISTERS +// + +// +// Use NVIDIA Turing/Volta+ partitioning operator (`match_any()`)? +// +#ifdef RS_SCATTER_ENABLE_NV_MATCH +#extension GL_NV_shader_subgroup_partitioned : require +#endif + +// +// Store prefix intermediates in registers? +// +#ifdef RS_SCATTER_DISABLE_COMPONENTS_IN_REGISTERS +#define RS_PREFIX_DISABLE_COMPONENTS_IN_REGISTERS +#endif + +// +// Buffer reference macros and push constants +// +#include "bufref.h" +#include "push.h" + +// +// Push constants for scatter shader +// +RS_STRUCT_PUSH_SCATTER(); + +layout(push_constant) uniform block_push +{ + rs_push_scatter push; +}; + +// +// Subgroup uniform support +// +#if defined(RS_SCATTER_SUBGROUP_UNIFORM_DISABLE) && defined(GL_EXT_subgroupuniform_qualifier) +#extension GL_EXT_subgroupuniform_qualifier : required +#define RS_SUBGROUP_UNIFORM subgroupuniformEXT +#else +#define RS_SUBGROUP_UNIFORM +#endif + +// +// Check all mandatory switches are defined +// + +// What's the size of the keyval? +#ifndef RS_KEYVAL_DWORDS +#error "Undefined: RS_KEYVAL_DWORDS" +#endif + +// Which keyval dword does this shader bitfieldExtract() bits? +#ifndef RS_SCATTER_KEYVAL_DWORD_BASE +#error "Undefined: RS_SCATTER_KEYVAL_DWORD_BASE" +#endif + +// +#ifndef RS_SCATTER_BLOCK_ROWS +#error "Undefined: RS_SCATTER_BLOCK_ROWS" +#endif + +// +#ifndef RS_SCATTER_SUBGROUP_SIZE_LOG2 +#error "Undefined: RS_SCATTER_SUBGROUP_SIZE_LOG2" +#endif + +// +#ifndef RS_SCATTER_WORKGROUP_SIZE_LOG2 +#error "Undefined: RS_SCATTER_WORKGROUP_SIZE_LOG2" +#endif + +// +// Status masks are defined differently for the scatter_even and +// scatter_odd shaders. +// +#ifndef RS_PARTITION_STATUS_INVALID +#error "Undefined: RS_PARTITION_STATUS_INVALID" +#endif + +#ifndef RS_PARTITION_STATUS_REDUCTION +#error "Undefined: RS_PARTITION_STATUS_REDUCTION" +#endif + +#ifndef RS_PARTITION_STATUS_PREFIX +#error "Undefined: RS_PARTITION_STATUS_PREFIX" +#endif + +// +// Assumes (RS_RADIX_LOG2 == 8) +// +// Error if this ever changes! +// +#if (RS_RADIX_LOG2 != 8) +#error "Error: (RS_RADIX_LOG2 != 8)" +#endif + +// +// Masks are different for scatter_even/odd. 
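For orientation only, here is a scalar C reference of what the alternating even/odd scatter passes accomplish for a 64-bit keyval: eight 8-bit passes that ping-pong between two buffers, leaving the result back in the "even" buffer. This sketches the overall pass structure, not the GPU algorithm, and the function name is invented.

#include <assert.h>
#include <stdint.h>

#define RADIX_LOG2 8
#define RADIX_SIZE (1u << RADIX_LOG2)

/* Scalar LSD radix sort of 64-bit keyvals: one stable counting-sort pass
   per 8-bit digit, ping-ponging between the "even" and "odd" buffers. */
static void radix_sort_u64(uint64_t *even, uint64_t *odd, uint32_t count)
{
  uint64_t *in  = even;
  uint64_t *out = odd;

  for (uint32_t pass = 0; pass < sizeof(uint64_t); pass++)
    {
      const uint32_t pass_offset = pass * RADIX_LOG2;
      uint32_t       histogram[RADIX_SIZE] = { 0 };

      for (uint32_t ii = 0; ii < count; ii++)
        histogram[(in[ii] >> pass_offset) & (RADIX_SIZE - 1)]++;

      /* Exclusive prefix scan of the histogram. */
      uint32_t exc = 0;
      for (uint32_t digit = 0; digit < RADIX_SIZE; digit++)
        {
          const uint32_t tmp = histogram[digit];
          histogram[digit]   = exc;
          exc += tmp;
        }

      /* Stable scatter. */
      for (uint32_t ii = 0; ii < count; ii++)
        out[histogram[(in[ii] >> pass_offset) & (RADIX_SIZE - 1)]++] = in[ii];

      /* Ping-pong. */
      uint64_t *tmp = in;
      in  = out;
      out = tmp;
    }
  /* After 8 passes the sorted keyvals are back in "even". */
}

int main(void)
{
  uint64_t even[5] = { 0x2A5ull << 32, 3, 0xFFFFFFFFFFFFFFFFull, 2, 3 };
  uint64_t odd[5];

  radix_sort_u64(even, odd, 5);

  for (uint32_t ii = 1; ii < 5; ii++)
    assert(even[ii - 1] <= even[ii]);
  return 0;
}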
+// +// clang-format off +#define RS_PARTITION_MASK_INVALID (RS_PARTITION_STATUS_INVALID << 30) +#define RS_PARTITION_MASK_REDUCTION (RS_PARTITION_STATUS_REDUCTION << 30) +#define RS_PARTITION_MASK_PREFIX (RS_PARTITION_STATUS_PREFIX << 30) +#define RS_PARTITION_MASK_STATUS 0xC0000000 +#define RS_PARTITION_MASK_COUNT 0x3FFFFFFF +// clang-format on + +// +// Local macros +// +// clang-format off +#define RS_KEYVAL_SIZE (RS_KEYVAL_DWORDS * 4) +#define RS_WORKGROUP_SIZE (1 << RS_SCATTER_WORKGROUP_SIZE_LOG2) +#define RS_SUBGROUP_SIZE (1 << RS_SCATTER_SUBGROUP_SIZE_LOG2) +#define RS_WORKGROUP_SUBGROUPS (RS_WORKGROUP_SIZE / RS_SUBGROUP_SIZE) +#define RS_SUBGROUP_KEYVALS (RS_SCATTER_BLOCK_ROWS * RS_SUBGROUP_SIZE) +#define RS_BLOCK_KEYVALS (RS_SCATTER_BLOCK_ROWS * RS_WORKGROUP_SIZE) +#define RS_RADIX_MASK ((1 << RS_RADIX_LOG2) - 1) +// clang-format on + +// +// Validate number of keyvals fit in a uint16_t. +// +#if (RS_BLOCK_KEYVALS >= 65536) +#error "Error: (RS_BLOCK_KEYVALS >= 65536)" +#endif + +// +// Keyval type +// +#if (RS_KEYVAL_DWORDS == 1) +#define RS_KEYVAL_TYPE uint32_t +#elif (RS_KEYVAL_DWORDS == 2) +#define RS_KEYVAL_TYPE u32vec2 +#else +#error "Error: Unsupported RS_KEYVAL_DWORDS" +#endif + +// +// Set up match mask +// +#if (RS_SUBGROUP_SIZE <= 32) +#if (RS_SUBGROUP_SIZE == 32) +#define RS_SUBGROUP_MASK 0xFFFFFFFF +#else +#define RS_SUBGROUP_MASK ((1 << RS_SUBGROUP_SIZE) - 1) +#endif +#endif + +// +// Determine at compile time the base of the final iteration for +// workgroups smaller than RS_RADIX_SIZE. +// +#if (RS_WORKGROUP_SIZE < RS_RADIX_SIZE) +#define RS_WORKGROUP_BASE_FINAL ((RS_RADIX_SIZE / RS_WORKGROUP_SIZE) * RS_WORKGROUP_SIZE) +#endif + +// +// Max macro +// +#define RS_MAX_2(a_, b_) (((a_) >= (b_)) ? (a_) : (b_)) + +// +// Select a keyval dword +// +#if (RS_KEYVAL_DWORDS == 1) +#define RS_KV_DWORD(kv_, dword_) (kv_) +#else +#define RS_KV_DWORD(kv_, dword_) (kv_)[dword_] +#endif + +// +// Is bitfield extract faster? +// +#ifdef RS_SCATTER_ENABLE_BITFIELD_EXTRACT +//---------------------------------------------------------------------- +// +// Test a bit in a radix digit +// +#define RS_BIT_IS_ONE(val_, bit_) (bitfieldExtract(val_, bit_, 1) != 0) + +// +// Extract a keyval digit +// +#if (RS_KEYVAL_DWORDS == 1) +#define RS_KV_EXTRACT_DIGIT(kv_) bitfieldExtract(kv_, int32_t(push.pass_offset), RS_RADIX_LOG2) +#else +#define RS_KV_EXTRACT_DIGIT(kv_) \ + bitfieldExtract(kv_[RS_SCATTER_KEYVAL_DWORD_BASE], int32_t(push.pass_offset), RS_RADIX_LOG2) +#endif +//---------------------------------------------------------------------- +#else +//---------------------------------------------------------------------- +// +// Test a bit in a radix digit +// +#define RS_BIT_IS_ONE(val_, bit_) (((val_) & (1 << (bit_))) != 0) + +// +// Extract a keyval digit +// +#if (RS_KEYVAL_DWORDS == 1) +#define RS_KV_EXTRACT_DIGIT(kv_) ((kv_ >> push.pass_offset) & RS_RADIX_MASK) +#else +#define RS_KV_EXTRACT_DIGIT(kv_) \ + ((kv_[RS_SCATTER_KEYVAL_DWORD_BASE] >> push.pass_offset) & RS_RADIX_MASK) +#endif +//---------------------------------------------------------------------- +#endif + +// +// Load prefix limits before loading prefix function and before +// calculating SMEM limits. +// +#include "prefix_limits.h" + +// +// - The lookback span is RS_RADIX_SIZE dwords and overwrites the +// ballots span. +// +// - The histogram span is RS_RADIX_SIZE dwords +// +// - The keyvals span is at least one dword per keyval in the +// workgroup. This span overwrites anything past the lookback +// radix span. 
+// +// Shared memory map phase 1: +// +// < LOOKBACK > < HISTOGRAM > < PREFIX > ... +// +// Shared memory map phase 3: +// +// < LOOKBACK > < REORDER > ... +// +// FIXME(allanmac): Create a spreadsheet showing the exact shared +// memory footprint (RS_SMEM_DWORDS) for a configuration. +// +// | Dwords | Bytes +// ----------+-------------------------------------------+-------- +// Lookback | 256 | 1 KB +// Histogram | 256 | 1 KB +// Prefix | 4-84 | 16-336 +// Reorder | RS_WORKGROUP_SIZE * RS_SCATTER_BLOCK_ROWS | 2-8 KB +// +// clang-format off +#define RS_SMEM_LOOKBACK_SIZE RS_RADIX_SIZE +#define RS_SMEM_HISTOGRAM_SIZE RS_RADIX_SIZE +#define RS_SMEM_REORDER_SIZE (RS_SCATTER_BLOCK_ROWS * RS_WORKGROUP_SIZE) + +#define RS_SMEM_DWORDS_PHASE_1 (RS_SMEM_LOOKBACK_SIZE + RS_SMEM_HISTOGRAM_SIZE + RS_SWEEP_SIZE) +#define RS_SMEM_DWORDS_PHASE_2 (RS_SMEM_LOOKBACK_SIZE + RS_SMEM_REORDER_SIZE) + +#define RS_SMEM_DWORDS RS_MAX_2(RS_SMEM_DWORDS_PHASE_1, RS_SMEM_DWORDS_PHASE_2) + +#define RS_SMEM_LOOKBACK_OFFSET 0 +#define RS_SMEM_HISTOGRAM_OFFSET (RS_SMEM_LOOKBACK_OFFSET + RS_SMEM_LOOKBACK_SIZE) +#define RS_SMEM_PREFIX_OFFSET (RS_SMEM_HISTOGRAM_OFFSET + RS_SMEM_HISTOGRAM_SIZE) +#define RS_SMEM_REORDER_OFFSET (RS_SMEM_LOOKBACK_OFFSET + RS_SMEM_LOOKBACK_SIZE) +// clang-format on + +// +// +// +layout(local_size_x = RS_WORKGROUP_SIZE) in; + +// +// +// +layout(buffer_reference, std430) buffer buffer_rs_kv +{ + RS_KEYVAL_TYPE extent[]; +}; + +layout(buffer_reference, std430) buffer buffer_rs_histogram // single histogram +{ + uint32_t extent[]; +}; + +layout(buffer_reference, std430) buffer buffer_rs_partitions +{ + uint32_t extent[]; +}; + +// +// Declare shared memory +// +struct rs_scatter_smem +{ + uint32_t extent[RS_SMEM_DWORDS]; +}; + +shared rs_scatter_smem smem; + +// +// The shared memory barrier is either subgroup-wide or +// workgroup-wide. +// +#if (RS_WORKGROUP_SUBGROUPS == 1) +#define RS_BARRIER() subgroupBarrier() +#else +#define RS_BARRIER() barrier() +#endif + +// +// If multi-subgroup then define shared memory +// +#if (RS_WORKGROUP_SUBGROUPS > 1) + +//---------------------------------------- +#define RS_PREFIX_SWEEP0(idx_) smem.extent[RS_SMEM_PREFIX_OFFSET + RS_SWEEP_0_OFFSET + (idx_)] +//---------------------------------------- + +#if (RS_SWEEP_1_SIZE > 0) +//---------------------------------------- +#define RS_PREFIX_SWEEP1(idx_) smem.extent[RS_SMEM_PREFIX_OFFSET + RS_SWEEP_1_OFFSET + (idx_)] +//---------------------------------------- +#endif + +#if (RS_SWEEP_2_SIZE > 0) +//---------------------------------------- +#define RS_PREFIX_SWEEP2(idx_) smem.extent[RS_SMEM_PREFIX_OFFSET + RS_SWEEP_2_OFFSET + (idx_)] +//---------------------------------------- +#endif + +#endif + +// +// Define prefix load/store functions +// +// clang-format off +#if (RS_WORKGROUP_SUBGROUPS == 1) +#define RS_PREFIX_LOAD(idx_) smem.extent[RS_SMEM_HISTOGRAM_OFFSET + gl_SubgroupInvocationID + (idx_)] +#define RS_PREFIX_STORE(idx_) smem.extent[RS_SMEM_HISTOGRAM_OFFSET + gl_SubgroupInvocationID + (idx_)] +#else +#define RS_PREFIX_LOAD(idx_) smem.extent[RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x + (idx_)] +#define RS_PREFIX_STORE(idx_) smem.extent[RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x + (idx_)] +#endif +// clang-format on + +// +// Load the prefix function +// +// The prefix function operates on shared memory so there are no +// arguments. 
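Evaluating the shared-memory macros above by hand is easy to get wrong, so the small C program below re-derives RS_SMEM_DWORDS from the constants in targets/u64/config.h (workgroup 256, subgroup 64, 14 block rows). It only restates the arithmetic already encoded in the macros.

#include <stdint.h>
#include <stdio.h>

#define MAX2(a_, b_) (((a_) >= (b_)) ? (a_) : (b_))

int main(void)
{
  /* From push.h and targets/u64/config.h. */
  const uint32_t radix_size     = 1u << 8;  /* RS_RADIX_SIZE                       */
  const uint32_t workgroup_size = 1u << 8;  /* 1 << RS_SCATTER_WORKGROUP_SIZE_LOG2 */
  const uint32_t subgroup_size  = 1u << 6;  /* 1 << RS_SCATTER_SUBGROUP_SIZE_LOG2  */
  const uint32_t block_rows     = 14;       /* RS_SCATTER_BLOCK_ROWS               */

  /* prefix_limits.h (multi-subgroup case) */
  const uint32_t sweep_0 = radix_size / subgroup_size;  /* 4 */
  const uint32_t sweep_1 = sweep_0 / subgroup_size;     /* 0 */
  const uint32_t sweep_2 = sweep_1 / subgroup_size;     /* 0 */
  const uint32_t sweep   = sweep_0 + sweep_1 + sweep_2;

  /* scatter.glsl */
  const uint32_t lookback  = radix_size;
  const uint32_t histogram = radix_size;
  const uint32_t reorder   = block_rows * workgroup_size;

  const uint32_t phase_1 = lookback + histogram + sweep;  /* 516  */
  const uint32_t phase_2 = lookback + reorder;            /* 3840 */
  const uint32_t dwords  = MAX2(phase_1, phase_2);

  printf("RS_SMEM_DWORDS = %u (%u bytes)\n", dwords, dwords * 4);
  return 0;
}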
+// +#define RS_PREFIX_ARGS // EMPTY + +#include "prefix.h" + +// +// Zero the SMEM histogram +// +void +rs_histogram_zero() +{ +#if (RS_WORKGROUP_SUBGROUPS == 1) + + const uint32_t smem_offset = RS_SMEM_HISTOGRAM_OFFSET + gl_SubgroupInvocationID; + + [[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE) + { + smem.extent[smem_offset + ii] = 0; + } + +#elif (RS_WORKGROUP_SIZE < RS_RADIX_SIZE) + + const uint32_t smem_offset = RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x; + + [[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE) + { + smem.extent[smem_offset + ii] = 0; + } + +#if (RS_WORKGROUP_BASE_FINAL < RS_RADIX_SIZE) + const uint32_t smem_offset_final = smem_offset + RS_WORKGROUP_BASE_FINAL; + + if (smem_offset_final < RS_RADIX_SIZE) + { + smem.histogram[smem_offset_final] = 0; + } +#endif + +#elif (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE) + +#if (RS_WORKGROUP_SIZE > RS_RADIX_SIZE) + if (gl_LocalInvocationID.x < RS_RADIX_SIZE) +#endif + { + smem.extent[RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x] = 0; + } + +#endif + + RS_BARRIER(); +} + +// +// Perform a workgroup-wide match operation that computes both a +// workgroup-wide index for each keyval and a workgroup-wide +// histogram. +// +// FIXME(allanmac): Special case (RS_WORKGROUP_SUBGROUPS==1) +// +void +rs_histogram_rank(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS], + out uint32_t kr[RS_SCATTER_BLOCK_ROWS]) +{ + // clang-format off +#define RS_HISTOGRAM_LOAD(digit_) smem.extent[RS_SMEM_HISTOGRAM_OFFSET + (digit_)] +#define RS_HISTOGRAM_STORE(digit_, count_) smem.extent[RS_SMEM_HISTOGRAM_OFFSET + (digit_)] = (count_) + // clang-format on + + //---------------------------------------------------------------------- + // + // Use the Volta/Turing `match.sync` instruction. + // + // Note that performance is quite poor and the break-even for + // `match.sync` requires more bits. + // + //---------------------------------------------------------------------- +#ifdef RS_SCATTER_ENABLE_NV_MATCH + + // + // 32 + // +#if (RS_SUBGROUP_SIZE == 32) + + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + // + // NOTE(allanmac): Unfortunately there is no `match.any.sync.b8` + // + // TODO(allanmac): Consider using the `atomicOr()` match approach + // described by Adinets since Volta/Turing have extremely fast + // atomic smem operations. + // + const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]); + const uint32_t match = subgroupPartitionNV(digit).x; + + kr[ii] = (bitCount(match) << 16) | bitCount(match & gl_SubgroupLeMask.x); + } + + // + // Undefined! + // +#else +#error "Error: rs_histogram_rank() undefined for subgroup size" +#endif + + //---------------------------------------------------------------------- + // + // Default is to emulate a `match` operation with ballots. + // + //---------------------------------------------------------------------- +#elif !defined(RS_SCATTER_ENABLE_BROADCAST_MATCH) + + // + // 64 + // +#if (RS_SUBGROUP_SIZE == 64) + + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]); + + u32vec2 match; + + { + const bool is_one = RS_BIT_IS_ONE(digit, 0); + const u32vec2 ballot = subgroupBallot(is_one).xy; + const uint32_t mask = is_one ? 
0 : 0xFFFFFFFF; + + match.x = (ballot.x ^ mask); + match.y = (ballot.y ^ mask); + } + + [[unroll]] for (int32_t bit = 1; bit < RS_RADIX_LOG2; bit++) + { + const bool is_one = RS_BIT_IS_ONE(digit, bit); + const u32vec2 ballot = subgroupBallot(is_one).xy; + const uint32_t mask = is_one ? 0 : 0xFFFFFFFF; + + match.x &= (ballot.x ^ mask); + match.y &= (ballot.y ^ mask); + } + + kr[ii] = ((bitCount(match.x) + bitCount(match.y)) << 16) | + (bitCount(match.x & gl_SubgroupLeMask.x) + // + bitCount(match.y & gl_SubgroupLeMask.y)); + } + + // + // <= 32 + // +#elif ((RS_SUBGROUP_SIZE <= 32) && !defined(RS_SCATTER_ENABLE_NV_MATCH)) + + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]); + + uint32_t match; + + { + const bool is_one = RS_BIT_IS_ONE(digit, 0); + const uint32_t ballot = subgroupBallot(is_one).x; + const uint32_t mask = is_one ? 0 : RS_SUBGROUP_MASK; + + match = (ballot ^ mask); + } + + [[unroll]] for (int32_t bit = 1; bit < RS_RADIX_LOG2; bit++) + { + const bool is_one = RS_BIT_IS_ONE(digit, bit); + const uint32_t ballot = subgroupBallot(is_one).x; + const uint32_t mask = is_one ? 0 : RS_SUBGROUP_MASK; + + match &= (ballot ^ mask); + } + + kr[ii] = (bitCount(match) << 16) | bitCount(match & gl_SubgroupLeMask.x); + } + + // + // Undefined! + // +#else +#error "Error: rs_histogram_rank() undefined for subgroup size" +#endif + + //---------------------------------------------------------------------- + // + // Emulate a `match` operation with broadcasts. + // + // In general, using broadcasts is a win for narrow subgroups. + // + //---------------------------------------------------------------------- +#else + + // + // 64 + // +#if (RS_SUBGROUP_SIZE == 64) + + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]); + + u32vec2 match; + + // subgroup invocation 0 + { + match[0] = (subgroupBroadcast(digit, 0) == digit) ? (1u << 0) : 0; + } + + // subgroup invocations 1-31 + [[unroll]] for (int32_t jj = 1; jj < 32; jj++) + { + match[0] |= (subgroupBroadcast(digit, jj) == digit) ? (1u << jj) : 0; + } + + // subgroup invocation 32 + { + match[1] = (subgroupBroadcast(digit, 32) == digit) ? (1u << 0) : 0; + } + + // subgroup invocations 33-63 + [[unroll]] for (int32_t jj = 1; jj < 32; jj++) + { + match[1] |= (subgroupBroadcast(digit, jj) == digit) ? (1u << jj) : 0; + } + + kr[ii] = ((bitCount(match.x) + bitCount(match.y)) << 16) | + (bitCount(match.x & gl_SubgroupLeMask.x) + // + bitCount(match.y & gl_SubgroupLeMask.y)); + } + + // + // <= 32 + // +#elif ((RS_SUBGROUP_SIZE <= 32) && !defined(RS_SCATTER_ENABLE_NV_MATCH)) + + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]); + + // subgroup invocation 0 + uint32_t match = (subgroupBroadcast(digit, 0) == digit) ? (1u << 0) : 0; + + // subgroup invocations 1-(RS_SUBGROUP_SIZE-1) + [[unroll]] for (int32_t jj = 1; jj < RS_SUBGROUP_SIZE; jj++) + { + match |= (subgroupBroadcast(digit, jj) == digit) ? (1u << jj) : 0; + } + + kr[ii] = (bitCount(match) << 16) | bitCount(match & gl_SubgroupLeMask.x); + } + + // + // Undefined! + // +#else +#error "Error: rs_histogram_rank() undefined for subgroup size" +#endif + +#endif + + // + // This is a little unconventional but cycling through a subgroup at + // a time is a performance win on the tested architectures. 
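A scalar C model of the ballot-based match emulation above, for a 32-lane subgroup: per digit bit, a "ballot" of which lanes hold a 1 is taken, lanes holding a 0 invert it, and ANDing the eight per-bit masks leaves exactly the lanes whose full 8-bit digit matches; the rank is then a popcount over the lanes at or below the current one. The ballot()/match_mask() helpers are stand-ins for the subgroup intrinsics, and __builtin_popcount() is a GCC/Clang builtin.

#include <assert.h>
#include <stdint.h>

#define SUBGROUP_SIZE 32
#define RADIX_LOG2    8

/* CPU stand-in for subgroupBallot(): bit ii is set if lane ii "votes" true. */
static uint32_t ballot(const uint8_t digits[SUBGROUP_SIZE], uint32_t bit)
{
  uint32_t b = 0;
  for (uint32_t lane = 0; lane < SUBGROUP_SIZE; lane++)
    if ((digits[lane] >> bit) & 1)
      b |= (1u << lane);
  return b;
}

/* For one lane: mask of lanes with an equal digit, as the shader derives it. */
static uint32_t match_mask(const uint8_t digits[SUBGROUP_SIZE], uint32_t lane)
{
  uint32_t match = 0xFFFFFFFFu;

  for (uint32_t bit = 0; bit < RADIX_LOG2; bit++)
    {
      const uint32_t b    = ballot(digits, bit);
      const uint32_t mask = ((digits[lane] >> bit) & 1) ? 0 : 0xFFFFFFFFu;

      match &= (b ^ mask);
    }
  return match;
}

int main(void)
{
  uint8_t digits[SUBGROUP_SIZE];

  for (uint32_t lane = 0; lane < SUBGROUP_SIZE; lane++)
    digits[lane] = (uint8_t)(lane & 3);  /* four distinct digits, repeated */

  const uint32_t lane    = 9;                       /* holds digit 1           */
  const uint32_t match   = match_mask(digits, lane);
  const uint32_t le_mask = (1u << (lane + 1)) - 1;  /* gl_SubgroupLeMask.x     */

  assert(__builtin_popcount(match) == 8);           /* 8 lanes hold digit 1    */
  assert(__builtin_popcount(match & le_mask) == 3); /* 1-based rank among them */
  return 0;
}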
+ // + for (uint32_t ii = 0; ii < RS_WORKGROUP_SUBGROUPS; ii++) + { + if (gl_SubgroupID == ii) + { + [[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++) + { + const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[jj]); + const uint32_t prev = RS_HISTOGRAM_LOAD(digit); + const uint32_t rank = kr[jj] & 0xFFFF; + const uint32_t count = kr[jj] >> 16; + + kr[jj] = prev + rank; + + if (rank == count) + { + RS_HISTOGRAM_STORE(digit, (prev + count)); + } + + subgroupMemoryBarrierShared(); + } + } + + RS_BARRIER(); + } +} + +// +// Other partitions may lookback on this partition. +// +// Load the global exclusive prefix and for each subgroup +// store the exclusive prefix to shared memory and store the +// final inclusive prefix to global memory. +// +void +rs_first_prefix_store(restrict buffer_rs_partitions rs_partitions) +{ + // + // Define the histogram reference + // +#if (RS_WORKGROUP_SUBGROUPS == 1) + const uint32_t hist_offset = gl_SubgroupInvocationID * 4; +#else + const uint32_t hist_offset = gl_LocalInvocationID.x * 4; +#endif + + readonly RS_BUFREF_DEFINE_AT_OFFSET_UINT32(buffer_rs_histogram, + rs_histogram, + push.devaddr_histograms, + hist_offset); + +#if (RS_WORKGROUP_SUBGROUPS == 1) + //////////////////////////////////////////////////////////////////////////// + // + // (RS_WORKGROUP_SUBGROUPS == 1) + // + const uint32_t smem_offset_h = RS_SMEM_HISTOGRAM_OFFSET + gl_SubgroupInvocationID; + const uint32_t smem_offset_l = RS_SMEM_LOOKBACK_OFFSET + gl_SubgroupInvocationID; + + [[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE) + { + const uint32_t exc = rs_histogram.extent[ii]; + const uint32_t red = smem.extent[smem_offset_h + ii]; + + smem.extent[smem_offset_l + ii] = exc; + + const uint32_t inc = exc + red; + + atomicStore(rs_partitions.extent[ii], + inc | RS_PARTITION_MASK_PREFIX, + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsRelease); + } + +#elif (RS_WORKGROUP_SIZE < RS_RADIX_SIZE) + //////////////////////////////////////////////////////////////////////////// + // + // (RS_WORKGROUP_SIZE < RS_RADIX_SIZE) + // + const uint32_t smem_offset_h = RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x; + const uint32_t smem_offset_l = RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x; + + [[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE) + { + const uint32_t exc = rs_histogram.extent[ii]; + const uint32_t red = smem.extent[smem_offset_h + ii]; + + smem.extent[smem_offset_l + ii] = exc; + + const uint32_t inc = exc + red; + + atomicStore(rs_partitions.extent[ii], + inc | RS_PARTITION_MASK_PREFIX, + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsRelease); + } + +#if (RS_WORKGROUP_BASE_FINAL < RS_RADIX_SIZE) + const uint32_t smem_offset_final_h = smem_offset_h + RS_WORKGROUP_BASE_FINAL; + const uint32_t smem_offset_final_l = smem_offset_l + RS_WORKGROUP_BASE_FINAL; + + if (smem_offset_final < RS_RADIX_SIZE) + { + const uint32_t exc = rs_histogram.extent[RS_WORKGROUP_BASE_FINAL]; + const uint32_t red = smem.extent[smem_offset_final_h]; + + smem.extent[smem_offset_final_l] = exc; + + const uint32_t inc = exc + red; + + atomicStore(rs_partitions.extent[RS_WORKGROUP_BASE_FINAL], + inc | RS_PARTITION_MASK_PREFIX, + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsRelease); + } +#endif + +#elif (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE) + //////////////////////////////////////////////////////////////////////////// + // + // (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE) + // +#if (RS_WORKGROUP_SIZE > 
RS_RADIX_SIZE) + if (gl_LocalInvocationID.x < RS_RADIX_SIZE) +#endif + { + const uint32_t exc = rs_histogram.extent[0]; + const uint32_t red = smem.extent[RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x]; + + smem.extent[RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x] = exc; + + const uint32_t inc = exc + red; + + atomicStore(rs_partitions.extent[0], + inc | RS_PARTITION_MASK_PREFIX, + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsRelease); + } + +#endif +} + +// +// Atomically store the reduction to the global partition. +// +void +rs_reduction_store(restrict buffer_rs_partitions rs_partitions, + RS_SUBGROUP_UNIFORM const uint32_t partition_base) +{ +#if (RS_WORKGROUP_SUBGROUPS == 1) + //////////////////////////////////////////////////////////////////////////// + // + // (RS_WORKGROUP_SUBGROUPS == 1) + // + const uint32_t smem_offset = RS_SMEM_HISTOGRAM_OFFSET + gl_SubgroupInvocationID; + + [[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE) + { + const uint32_t red = smem.extent[smem_offset + ii]; + + atomicStore(rs_partitions.extent[partition_base + ii], + red | RS_PARTITION_MASK_REDUCTION, + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsRelease); + } + +#elif (RS_WORKGROUP_SIZE < RS_RADIX_SIZE) + //////////////////////////////////////////////////////////////////////////// + // + // (RS_WORKGROUP_SIZE < RS_RADIX_SIZE) + // + const uint32_t smem_offset = RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x; + + [[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE) + { + const uint32_t red = smem.extent[smem_offset + ii]; + + atomicStore(rs_partitions.extent[partition_base + ii], + red | RS_PARTITION_MASK_REDUCTION, + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsRelease); + } + +#if (RS_WORKGROUP_BASE_FINAL < RS_RADIX_SIZE) + const uint32_t smem_offset_final = smem_offset + RS_WORKGROUP_BASE_FINAL; + + if (smem_offset_final < RS_RADIX_SIZE) + { + const uint32_t red = smem.extent[smem_offset_final]; + + atomicStore(rs_partitions.extent[partition_base + RS_WORKGROUP_BASE_FINAL], + red | RS_PARTITION_MASK_REDUCTION, + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsRelease); + } +#endif + +#elif (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE) + //////////////////////////////////////////////////////////////////////////// + // + // (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE) + // +#if (RS_WORKGROUP_SIZE > RS_RADIX_SIZE) + if (gl_LocalInvocationID.x < RS_RADIX_SIZE) +#endif + { + const uint32_t red = smem.extent[RS_SMEM_HISTOGRAM_OFFSET + gl_LocalInvocationID.x]; + + atomicStore(rs_partitions.extent[partition_base], + red | RS_PARTITION_MASK_REDUCTION, + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsRelease); + } + +#endif +} + +// +// Lookback and accumulate reductions until a PREFIX partition is +// reached and then update this workgroup's partition and local +// histogram prefix. +// +// TODO(allanmac): Consider reenabling the cyclic/ring buffer of +// partitions in order to save memory. It actually adds complexity +// but reduces the amount of pre-scatter buffer zeroing. 
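Before the function bodies below, a serial C sketch of the lookback itself: walk backwards one partition at a time, spin while a slot is still INVALID, accumulate REDUCTION counts, and stop as soon as a PREFIX is found. The names are illustrative and the even-pass encoding is assumed; on the device, each of the 256 digit columns performs this walk independently and with real atomics.

#include <stdint.h>

#define STATUS_MASK      0xC0000000u
#define COUNT_MASK       0x3FFFFFFFu
#define STATUS_INVALID   0x00000000u  /* even-pass encoding */
#define STATUS_REDUCTION 0x40000000u
#define STATUS_PREFIX    0x80000000u

/* Serial model: partitions[] holds one published word per earlier partition
   for a single digit column; returns the exclusive prefix for this column. */
static uint32_t lookback(const volatile uint32_t *partitions, uint32_t partition_id)
{
  uint32_t exc = 0;

  for (uint32_t prev = partition_id; prev-- > 0;)
    {
      uint32_t word;

      /* spin until the producer has stored something valid */
      do
        {
          word = partitions[prev];
        }
      while ((word & STATUS_MASK) == STATUS_INVALID);

      exc += (word & COUNT_MASK);

      if ((word & STATUS_MASK) == STATUS_PREFIX)
        break;  /* inclusive prefix found: stop accumulating reductions */
    }

  return exc;
}

int main(void)
{
  /* Partition 0 published a PREFIX of 100; partition 1 a REDUCTION of 7. */
  uint32_t partitions[3] = { 100u | STATUS_PREFIX, 7u | STATUS_REDUCTION, 0u };

  return lookback(partitions, 2) == 107u ? 0 : 1;
}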
+// +void +rs_lookback_store(restrict buffer_rs_partitions rs_partitions, + RS_SUBGROUP_UNIFORM const uint32_t partition_base) +{ +#if (RS_WORKGROUP_SUBGROUPS == 1) + //////////////////////////////////////////////////////////////////////////// + // + // (RS_WORKGROUP_SUBGROUPS == 1) + // + const uint32_t smem_offset = RS_SMEM_LOOKBACK_OFFSET + gl_SubgroupInvocationID; + + [[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE) + { + uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE; + uint32_t exc = 0; + + // + // NOTE: Each workgroup invocation can proceed independently. + // Subgroups and workgroups do NOT have to coordinate. + // + while (true) + { + const uint32_t prev = atomicLoad(rs_partitions.extent[partition_base_prev + ii], + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsAcquire); + + // spin until valid + if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID) + { + continue; + } + + exc += (prev & RS_PARTITION_MASK_COUNT); + + if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX) + { + // continue accumulating reductions + partition_base_prev -= RS_RADIX_SIZE; + continue; + } + + // + // Otherwise, save the exclusive scan and atomically transform + // the reduction into an inclusive prefix status math: + // + // reduction + 1 = prefix + // + smem.extent[smem_offset + ii] = exc; + + atomicAdd(rs_partitions.extent[partition_base + ii], + exc | (1 << 30), + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsAcquireRelease); + break; + } + } + +#elif (RS_WORKGROUP_SIZE < RS_RADIX_SIZE) + //////////////////////////////////////////////////////////////////////////// + // + // (RS_WORKGROUP_SIZE < RS_RADIX_SIZE) + // + const uint32_t smem_offset = RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x; + + [[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE) + { + uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE; + uint32_t exc = 0; + + // + // NOTE: Each workgroup invocation can proceed independently. + // Subgroups and workgroups do NOT have to coordinate. + // + while (true) + { + const uint32_t prev = atomicLoad(rs_partitions.extent[partition_base_prev + ii], + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsAcquire); + + // spin until valid + if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID) + { + continue; + } + + exc += (prev & RS_PARTITION_MASK_COUNT); + + if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX) + { + // continue accumulating reductions + partition_base_prev -= RS_RADIX_SIZE; + continue; + } + + // + // Otherwise, save the exclusive scan and atomically transform + // the reduction into an inclusive prefix status math: + // + // reduction + 1 = prefix + // + smem.extent[smem_offset + ii] = exc; + + atomicAdd(rs_partitions.extent[partition_base + ii], + exc | (1 << 30), + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsAcquireRelease); + break; + } + } + +#if (RS_WORKGROUP_BASE_FINAL < RS_RADIX_SIZE) + const uint32_t smem_offset_final = smem_offset + RS_WORKGROUP_BASE_FINAL; + + if (smem_offset_final < RS_SMEM_LOOKBACK_OFFSET + RS_RADIX_SIZE) + { + uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE; + uint32_t exc = 0; + + // + // NOTE: Each workgroup invocation can proceed independently. + // Subgroups and workgroups do NOT have to coordinate. 
+ // + while (true) + { + const uint32_t prev = atomicLoad(rs_partitions.extent[partition_base_prev + ii], + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsAcquire); + + // spin until valid + if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID) + { + continue; + } + + exc += (prev & RS_PARTITION_MASK_COUNT); + + if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX) + { + // continue accumulating reductions + partition_base_prev -= RS_RADIX_SIZE; + continue; + } + + // + // Otherwise, save the exclusive scan and atomically transform + // the reduction into an inclusive prefix status math: + // + // reduction + 1 = prefix + // + smem.extent[smem_offset + ii] = exc; + + atomicAdd(rs_partitions.extent[partition_base + RS_WORKGROUP_BASE_FINAL], + exc | (1 << 30), + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsAcquireRelease); + break; + } + } +#endif + +#elif (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE) + //////////////////////////////////////////////////////////////////////////// + // + // (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE) + // +#if (RS_WORKGROUP_SIZE > RS_RADIX_SIZE) + if (gl_LocalInvocationID.x < RS_RADIX_SIZE) +#endif + { + uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE; + uint32_t exc = 0; + + // + // NOTE: Each workgroup invocation can proceed independently. + // Subgroups and workgroups do NOT have to coordinate. + // + while (true) + { + const uint32_t prev = atomicLoad(rs_partitions.extent[partition_base_prev], + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsAcquire); + + // spin until valid + if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID) + { + continue; + } + + exc += (prev & RS_PARTITION_MASK_COUNT); + + if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX) + { + // continue accumulating reductions + partition_base_prev -= RS_RADIX_SIZE; + continue; + } + + // + // Otherwise, save the exclusive scan and atomically transform + // the reduction into an inclusive prefix status math: + // + // reduction + 1 = prefix + // + smem.extent[RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x] = exc; + + atomicAdd(rs_partitions.extent[partition_base], + exc | (1 << 30), + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsAcquireRelease); + break; + } + } + +#endif +} + +// +// Lookback and accumulate reductions until a PREFIX partition is +// reached and then update this workgroup's local histogram prefix. +// +// Skip updating this workgroup's partition because it's last. +// +void +rs_lookback_skip_store(restrict buffer_rs_partitions rs_partitions, + RS_SUBGROUP_UNIFORM const uint32_t partition_base) +{ +#if (RS_WORKGROUP_SUBGROUPS == 1) + //////////////////////////////////////////////////////////////////////////// + // + // (RS_WORKGROUP_SUBGROUPS == 1) + // + const uint32_t smem_offset = RS_SMEM_LOOKBACK_OFFSET + gl_SubgroupInvocationID; + + [[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_SUBGROUP_SIZE) + { + uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE; + uint32_t exc = 0; + + // + // NOTE: Each workgroup invocation can proceed independently. + // Subgroups and workgroups do NOT have to coordinate. 
+ // + while (true) + { + const uint32_t prev = atomicLoad(rs_partitions.extent[partition_base_prev + ii], + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsAcquire); + + // spin until valid + if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID) + { + continue; + } + + exc += (prev & RS_PARTITION_MASK_COUNT); + + if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX) + { + // continue accumulating reductions + partition_base_prev -= RS_RADIX_SIZE; + continue; + } + + // Otherwise, save the exclusive scan. + smem.extent[smem_offset + ii] = exc; + break; + } + } + +#elif (RS_WORKGROUP_SIZE < RS_RADIX_SIZE) + //////////////////////////////////////////////////////////////////////////// + // + // (RS_WORKGROUP_SIZE < RS_RADIX_SIZE) + // + const uint32_t smem_offset = RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x; + + [[unroll]] for (uint32_t ii = 0; ii < RS_RADIX_SIZE; ii += RS_WORKGROUP_SIZE) + { + uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE; + uint32_t exc = 0; + + // + // NOTE: Each workgroup invocation can proceed independently. + // Subgroups and workgroups do NOT have to coordinate. + // + while (true) + { + const uint32_t prev = atomicLoad(rs_partitions.extent[partition_base_prev + ii], + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsAcquire); + + // spin until valid + if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID) + { + continue; + } + + exc += (prev & RS_PARTITION_MASK_COUNT); + + if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX) + { + // continue accumulating reductions + partition_base_prev -= RS_RADIX_SIZE; + continue; + } + + // Otherwise, save the exclusive scan. + smem.extent[smem_offset + ii] = exc; + break; + } + } + +#if (RS_WORKGROUP_BASE_FINAL < RS_RADIX_SIZE) + const uint32_t smem_offset_final = smem_offset + RS_WORKGROUP_BASE_FINAL; + + if (smem_offset_final < RS_RADIX_SIZE) + { + uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE; + uint32_t exc = 0; + + // + // NOTE: Each workgroup invocation can proceed independently. + // Subgroups and workgroups do NOT have to coordinate. + // + while (true) + { + const uint32_t prev = + atomicLoad(rs_partitions.extent[partition_base_prev + RS_WORKGROUP_BASE_FINAL], + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsAcquire); + + // spin until valid + if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID) + { + continue; + } + + exc += (prev & RS_PARTITION_MASK_COUNT); + + if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX) + { + // continue accumulating reductions + partition_base_prev -= RS_RADIX_SIZE; + continue; + } + + // Otherwise, save the exclusive scan. + smem.extent[smem_offset_final] = exc; + break; + } + } +#endif + +#elif (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE) + //////////////////////////////////////////////////////////////////////////// + // + // (RS_WORKGROUP_SIZE >= RS_RADIX_SIZE) + // +#if (RS_WORKGROUP_SIZE > RS_RADIX_SIZE) + if (gl_LocalInvocationID.x < RS_RADIX_SIZE) +#endif + { + uint32_t partition_base_prev = partition_base - RS_RADIX_SIZE; + uint32_t exc = 0; + + // + // NOTE: Each workgroup invocation can proceed independently. + // Subgroups and workgroups do NOT have to coordinate. 
+ // + while (true) + { + const uint32_t prev = atomicLoad(rs_partitions.extent[partition_base_prev], + gl_ScopeQueueFamily, + gl_StorageSemanticsBuffer, + gl_SemanticsAcquire); + + // spin until valid + if ((prev & RS_PARTITION_MASK_STATUS) == RS_PARTITION_MASK_INVALID) + { + continue; + } + + exc += (prev & RS_PARTITION_MASK_COUNT); + + if ((prev & RS_PARTITION_MASK_STATUS) != RS_PARTITION_MASK_PREFIX) + { + // continue accumulating reductions + partition_base_prev -= RS_RADIX_SIZE; + continue; + } + + // Otherwise, save the exclusive scan. + smem.extent[RS_SMEM_LOOKBACK_OFFSET + gl_LocalInvocationID.x] = exc; + break; + } + } + +#endif +} + +// +// Compute a 1-based local index for each keyval by adding the 1-based +// rank to the local histogram prefix. +// +void +rs_rank_to_local(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS], + inout uint32_t kr[RS_SCATTER_BLOCK_ROWS]) +{ + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]); + const uint32_t exc = smem.extent[RS_SMEM_HISTOGRAM_OFFSET + digit]; + const uint32_t idx = exc + kr[ii]; + + kr[ii] |= (idx << 16); + } + + // + // Reordering phase will overwrite histogram span. + // + RS_BARRIER(); +} + +// +// Compute a 1-based local index for each keyval by adding the 1-based +// rank to the global histogram prefix. +// +void +rs_rank_to_global(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS], + inout uint32_t kr[RS_SCATTER_BLOCK_ROWS]) +{ + // + // Define the histogram reference + // + readonly RS_BUFREF_DEFINE(buffer_rs_histogram, rs_histogram, push.devaddr_histograms); + + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]); + const uint32_t exc = rs_histogram.extent[digit]; + + kr[ii] += (exc - 1); + } +} + +// +// Using the local indices, rearrange the keyvals into sorted order. +// +void +rs_reorder(inout RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS], inout uint32_t kr[RS_SCATTER_BLOCK_ROWS]) +{ + // clang-format off +#if (RS_WORKGROUP_SUBGROUPS == 1) + const uint32_t smem_base = RS_SMEM_REORDER_OFFSET + gl_SubgroupInvocationID; +#else + const uint32_t smem_base = RS_SMEM_REORDER_OFFSET + gl_LocalInvocationID.x; +#endif + // clang-format on + + [[unroll]] for (uint32_t ii = 0; ii < RS_KEYVAL_DWORDS; ii++) + { + // + // Store keyval dword to sorted location + // + [[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++) + { + const uint32_t smem_idx = (RS_SMEM_REORDER_OFFSET - 1) + (kr[jj] >> 16); + + smem.extent[smem_idx] = RS_KV_DWORD(kv[jj], ii); + } + + RS_BARRIER(); + + // + // Load keyval dword from sorted location + // + [[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++) + { + RS_KV_DWORD(kv[jj], ii) = smem.extent[smem_base + jj * RS_WORKGROUP_SIZE]; + } + + RS_BARRIER(); + } + + // + // Store the digit-index to sorted location + // + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + const uint32_t smem_idx = (RS_SMEM_REORDER_OFFSET - 1) + (kr[ii] >> 16); + + smem.extent[smem_idx] = uint32_t(kr[ii]); + } + + RS_BARRIER(); + + // + // Load kr[] from sorted location -- we only need the rank. + // + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + kr[ii] = smem.extent[smem_base + ii * RS_WORKGROUP_SIZE] & 0xFFFF; + } +} + +// +// Using the global/local indices obtained by a single workgroup, +// rearrange the keyvals into sorted order. 
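A compact C picture of what rs_reorder() above and rs_reorder_1() below buy: keyvals are first scattered into a block-sized scratch array at their block-local sorted position, then read back linearly, so the later global stores walk memory in order instead of scattering arbitrarily. The array names and sizes are illustrative.

#include <assert.h>
#include <stdint.h>

#define BLOCK_KEYVALS 8  /* illustrative; the real block is rows * workgroup */

int main(void)
{
  /* Block-local keyvals and their 0-based block-local sorted positions
     (as computed from the scanned histogram plus the per-digit rank). */
  const uint64_t kv[BLOCK_KEYVALS]  = { 30, 10, 20, 11, 31, 21, 12, 22 };
  const uint32_t idx[BLOCK_KEYVALS] = {  6,  0,  3,  1,  7,  4,  2,  5 };

  uint64_t scratch[BLOCK_KEYVALS];  /* stands in for the smem reorder span */

  /* Scatter within the block... */
  for (uint32_t ii = 0; ii < BLOCK_KEYVALS; ii++)
    scratch[idx[ii]] = kv[ii];

  /* ...then the global writeout reads the block back in order. */
  for (uint32_t ii = 1; ii < BLOCK_KEYVALS; ii++)
    assert(scratch[ii - 1] <= scratch[ii]);

  return 0;
}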
+// +void +rs_reorder_1(inout RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS], + inout uint32_t kr[RS_SCATTER_BLOCK_ROWS]) +{ + // clang-format off +#if (RS_WORKGROUP_SUBGROUPS == 1) + const uint32_t smem_base = RS_SMEM_REORDER_OFFSET + gl_SubgroupInvocationID; +#else + const uint32_t smem_base = RS_SMEM_REORDER_OFFSET + gl_LocalInvocationID.x; +#endif + // clang-format on + + [[unroll]] for (uint32_t ii = 0; ii < RS_KEYVAL_DWORDS; ii++) + { + // + // Store keyval dword to sorted location + // + [[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++) + { + const uint32_t smem_idx = RS_SMEM_REORDER_OFFSET + kr[jj]; + + smem.extent[smem_idx] = RS_KV_DWORD(kv[jj], ii); + } + + RS_BARRIER(); + + // + // Load keyval dword from sorted location + // + [[unroll]] for (uint32_t jj = 0; jj < RS_SCATTER_BLOCK_ROWS; jj++) + { + RS_KV_DWORD(kv[jj], ii) = smem.extent[smem_base + jj * RS_WORKGROUP_SIZE]; + } + + RS_BARRIER(); + } + + // + // Store the digit-index to sorted location + // + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + const uint32_t smem_idx = RS_SMEM_REORDER_OFFSET + kr[ii]; + + smem.extent[smem_idx] = uint32_t(kr[ii]); + } + + RS_BARRIER(); + + // + // Load kr[] from sorted location -- we only need the rank. + // + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + kr[ii] = smem.extent[smem_base + ii * RS_WORKGROUP_SIZE]; + } +} + +// +// Each subgroup loads RS_SCATTER_BLOCK_ROWS rows of keyvals into +// registers. +// +void +rs_load(out RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS]) +{ + // + // Set up buffer reference + // + const uint32_t kv_in_offset_keys = gl_WorkGroupID.x * RS_BLOCK_KEYVALS + + gl_SubgroupID * RS_SUBGROUP_KEYVALS + gl_SubgroupInvocationID; + + u32vec2 kv_in_offset; + + umulExtended(kv_in_offset_keys, + RS_KEYVAL_SIZE, + kv_in_offset.y, // msb + kv_in_offset.x); // lsb + + readonly RS_BUFREF_DEFINE_AT_OFFSET_U32VEC2(buffer_rs_kv, + rs_kv_in, + RS_DEVADDR_KEYVALS_IN(push), + kv_in_offset); + + // + // Load keyvals + // + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + kv[ii] = rs_kv_in.extent[ii * RS_SUBGROUP_SIZE]; + } +} + +// +// Convert local index to global +// +void +rs_local_to_global(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS], + inout uint32_t kr[RS_SCATTER_BLOCK_ROWS]) +{ + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + const uint32_t digit = RS_KV_EXTRACT_DIGIT(kv[ii]); + const uint32_t exc = smem.extent[RS_SMEM_LOOKBACK_OFFSET + digit]; + + kr[ii] += (exc - 1); + } +} + +// +// Store a single workgroup +// +void +rs_store(const RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS], const uint32_t kr[RS_SCATTER_BLOCK_ROWS]) +{ + // + // Define kv_out bufref + // + writeonly RS_BUFREF_DEFINE(buffer_rs_kv, rs_kv_out, RS_DEVADDR_KEYVALS_OUT(push)); + + // + // Store keyval: + // + // "out[ keyval.rank ] = keyval" + // + // FIXME(allanmac): Consider implementing an aligned writeout + // strategy to avoid excess global memory transactions. + // + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + rs_kv_out.extent[kr[ii]] = kv[ii]; + } +} + +// +// +// +void +main() +{ + // + // Load keyvals + // + RS_KEYVAL_TYPE kv[RS_SCATTER_BLOCK_ROWS]; + + rs_load(kv); + + // + // Zero shared histogram + // + // Ends with barrier. + // + rs_histogram_zero(); + + // + // Compute histogram and bin-relative keyval indices + // + // This histogram can immediately be used to update the partition + // with either a PREFIX or REDUCTION flag. + // + // Ends with a barrier. 
+ // + uint32_t kr[RS_SCATTER_BLOCK_ROWS]; + + rs_histogram_rank(kv, kr); + +// +// DEBUG +// +#if 0 // (RS_KEYVAL_DWORDS == 1) + { + writeonly RS_BUFREF_DEFINE_AT_OFFSET_UINT32(buffer_rs_kv, + rs_kv_out, + RS_DEVADDR_KEYVALS_OUT(push), + gl_LocalInvocationID.x * 4); + + [[unroll]] for (uint32_t ii = 0; ii < RS_SCATTER_BLOCK_ROWS; ii++) + { + rs_kv_out.extent[gl_WorkGroupID.x * RS_BLOCK_KEYVALS + ii * RS_WORKGROUP_SIZE] = kr[ii]; + } + + return; + } +#endif + + // + // When there is a single workgroup then the local and global + // exclusive scanned histograms are the same. + // + if (gl_NumWorkGroups.x == 1) + { + rs_rank_to_global(kv, kr); + +#ifndef RS_SCATTER_DISABLE_REORDER + rs_reorder_1(kv, kr); +#endif + + rs_store(kv, kr); + } + else + { + // + // Define partitions bufref + // +#if (RS_WORKGROUP_SUBGROUPS == 1) + const uint32_t partition_offset = gl_SubgroupInvocationID * 4; +#else + const uint32_t partition_offset = gl_LocalInvocationID.x * 4; +#endif + + RS_BUFREF_DEFINE_AT_OFFSET_UINT32(buffer_rs_partitions, + rs_partitions, + push.devaddr_partitions, + partition_offset); + + // + // The first partition is a special case. + // + if (gl_WorkGroupID.x == 0) + { + // + // Other workgroups may lookback on this partition. + // + // Load the global histogram and local histogram and store + // the exclusive prefix. + // + rs_first_prefix_store(rs_partitions); + } + else + { + // + // Otherwise, this is not the first workgroup. + // + RS_SUBGROUP_UNIFORM const uint32_t partition_base = gl_WorkGroupID.x * RS_RADIX_SIZE; + + // + // The last partition is a special case. + // + if (gl_WorkGroupID.x + 1 < gl_NumWorkGroups.x) + { + // + // Atomically store the reduction to the global partition. + // + rs_reduction_store(rs_partitions, partition_base); + + // + // Lookback and accumulate reductions until a PREFIX + // partition is reached and then update this workgroup's + // partition and local histogram prefix. + // + rs_lookback_store(rs_partitions, partition_base); + } + else + { + // + // Lookback and accumulate reductions until a PREFIX + // partition is reached and then update this workgroup's + // local histogram prefix. + // + // Skip updating this workgroup's partition because it's + // last. + // + rs_lookback_skip_store(rs_partitions, partition_base); + } + } + +#ifndef RS_SCATTER_DISABLE_REORDER + // + // Compute exclusive prefix scan of histogram. + // + // No barrier. + // + rs_prefix(); + + // + // Barrier before reading prefix scanned histogram. + // + RS_BARRIER(); + + // + // Convert keyval's rank to a local index + // + // Ends with a barrier. + // + rs_rank_to_local(kv, kr); + + // + // Reorder kv[] and kr[] + // + // Ends with a barrier. + // + rs_reorder(kv, kr); +#else + // + // Wait for lookback to complete. + // + RS_BARRIER(); +#endif + + // + // Convert local index to a global index. + // + rs_local_to_global(kv, kr); + + // + // Store keyvals to their new locations + // + rs_store(kv, kr); + } +} + +// +// +// diff --git a/src/amd/vulkan/radix_sort/shaders/scatter_0_even.comp b/src/amd/vulkan/radix_sort/shaders/scatter_0_even.comp new file mode 100644 index 00000000000..19b6d93189f --- /dev/null +++ b/src/amd/vulkan/radix_sort/shaders/scatter_0_even.comp @@ -0,0 +1,36 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +#version 460 + +// +// +// + +// clang-format off +#define RS_SCATTER_KEYVAL_DWORD_BASE 0 + +#define RS_PARTITION_STATUS_INVALID RS_PARTITION_STATUS_EVEN_INVALID +#define RS_PARTITION_STATUS_REDUCTION RS_PARTITION_STATUS_EVEN_REDUCTION +#define RS_PARTITION_STATUS_PREFIX RS_PARTITION_STATUS_EVEN_PREFIX + +#define RS_DEVADDR_KEYVALS_IN(push_) push_.devaddr_keyvals_even +#define RS_DEVADDR_KEYVALS_OUT(push_) push_.devaddr_keyvals_odd +// clang-format on + +// +// +// + +#extension GL_GOOGLE_include_directive : require + +// +// +// + +#include "scatter.glsl" + +// +// +// diff --git a/src/amd/vulkan/radix_sort/shaders/scatter_0_odd.comp b/src/amd/vulkan/radix_sort/shaders/scatter_0_odd.comp new file mode 100644 index 00000000000..20082cc54fc --- /dev/null +++ b/src/amd/vulkan/radix_sort/shaders/scatter_0_odd.comp @@ -0,0 +1,36 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#version 460 + +// +// +// + +// clang-format off +#define RS_SCATTER_KEYVAL_DWORD_BASE 0 + +#define RS_PARTITION_STATUS_INVALID RS_PARTITION_STATUS_ODD_INVALID +#define RS_PARTITION_STATUS_REDUCTION RS_PARTITION_STATUS_ODD_REDUCTION +#define RS_PARTITION_STATUS_PREFIX RS_PARTITION_STATUS_ODD_PREFIX + +#define RS_DEVADDR_KEYVALS_IN(push_) push_.devaddr_keyvals_odd +#define RS_DEVADDR_KEYVALS_OUT(push_) push_.devaddr_keyvals_even +// clang-format on + +// +// +// + +#extension GL_GOOGLE_include_directive : require + +// +// +// + +#include "scatter.glsl" + +// +// +// diff --git a/src/amd/vulkan/radix_sort/shaders/scatter_1_even.comp b/src/amd/vulkan/radix_sort/shaders/scatter_1_even.comp new file mode 100644 index 00000000000..9ca7f10b6c2 --- /dev/null +++ b/src/amd/vulkan/radix_sort/shaders/scatter_1_even.comp @@ -0,0 +1,36 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#version 460 + +// +// +// + +// clang-format off +#define RS_SCATTER_KEYVAL_DWORD_BASE 1 + +#define RS_PARTITION_STATUS_INVALID RS_PARTITION_STATUS_EVEN_INVALID +#define RS_PARTITION_STATUS_REDUCTION RS_PARTITION_STATUS_EVEN_REDUCTION +#define RS_PARTITION_STATUS_PREFIX RS_PARTITION_STATUS_EVEN_PREFIX + +#define RS_DEVADDR_KEYVALS_IN(push_) push_.devaddr_keyvals_even +#define RS_DEVADDR_KEYVALS_OUT(push_) push_.devaddr_keyvals_odd +// clang-format on + +// +// +// + +#extension GL_GOOGLE_include_directive : require + +// +// +// + +#include "scatter.glsl" + +// +// +// diff --git a/src/amd/vulkan/radix_sort/shaders/scatter_1_odd.comp b/src/amd/vulkan/radix_sort/shaders/scatter_1_odd.comp new file mode 100644 index 00000000000..c5fecfc1b58 --- /dev/null +++ b/src/amd/vulkan/radix_sort/shaders/scatter_1_odd.comp @@ -0,0 +1,36 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +#version 460 + +// +// +// + +// clang-format off +#define RS_SCATTER_KEYVAL_DWORD_BASE 1 + +#define RS_PARTITION_STATUS_INVALID RS_PARTITION_STATUS_ODD_INVALID +#define RS_PARTITION_STATUS_REDUCTION RS_PARTITION_STATUS_ODD_REDUCTION +#define RS_PARTITION_STATUS_PREFIX RS_PARTITION_STATUS_ODD_PREFIX + +#define RS_DEVADDR_KEYVALS_IN(push_) push_.devaddr_keyvals_odd +#define RS_DEVADDR_KEYVALS_OUT(push_) push_.devaddr_keyvals_even +// clang-format on + +// +// +// + +#extension GL_GOOGLE_include_directive : require + +// +// +// + +#include "scatter.glsl" + +// +// +// diff --git a/src/amd/vulkan/radix_sort/target.h b/src/amd/vulkan/radix_sort/target.h new file mode 100644 index 00000000000..2164389757d --- /dev/null +++ b/src/amd/vulkan/radix_sort/target.h @@ -0,0 +1,57 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_TARGET_H_ +#define SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_TARGET_H_ + +// +// +// + +#include + +// +// This structure packages target-specific configuration parameters. +// + +struct radix_sort_vk_target_config +{ + uint32_t keyval_dwords; + + struct + { + uint32_t workgroup_size_log2; + } init; + + struct + { + uint32_t workgroup_size_log2; + } fill; + + struct + { + uint32_t workgroup_size_log2; + uint32_t subgroup_size_log2; + uint32_t block_rows; + } histogram; + + struct + { + uint32_t workgroup_size_log2; + uint32_t subgroup_size_log2; + } prefix; + + struct + { + uint32_t workgroup_size_log2; + uint32_t subgroup_size_log2; + uint32_t block_rows; + } scatter; +}; + +// +// +// + +#endif // SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_TARGET_H_ diff --git a/src/amd/vulkan/radix_sort/targets/u64/config.h b/src/amd/vulkan/radix_sort/targets/u64/config.h new file mode 100644 index 00000000000..fa1a51eb017 --- /dev/null +++ b/src/amd/vulkan/radix_sort/targets/u64/config.h @@ -0,0 +1,34 @@ +// Copyright 2021 The Fuchsia Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_TARGETS_VENDORS_AMD_GCN3_U64_CONFIG_H_ +#define SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_TARGETS_VENDORS_AMD_GCN3_U64_CONFIG_H_ + +// +// +// + +// clang-format off +#define RS_KEYVAL_DWORDS 2 + +#define RS_FILL_WORKGROUP_SIZE_LOG2 7 +#define RS_FILL_BLOCK_ROWS 8 + +#define RS_HISTOGRAM_WORKGROUP_SIZE_LOG2 8 +#define RS_HISTOGRAM_SUBGROUP_SIZE_LOG2 6 +#define RS_HISTOGRAM_BLOCK_ROWS 14 + +#define RS_PREFIX_WORKGROUP_SIZE_LOG2 8 +#define RS_PREFIX_SUBGROUP_SIZE_LOG2 6 + +#define RS_SCATTER_WORKGROUP_SIZE_LOG2 8 +#define RS_SCATTER_SUBGROUP_SIZE_LOG2 6 +#define RS_SCATTER_BLOCK_ROWS 14 +// clang-format on + +// +// +// + +#endif // SRC_GRAPHICS_LIB_COMPUTE_RADIX_SORT_PLATFORMS_VK_TARGETS_VENDORS_AMD_GCN3_U64_CONFIG_H_
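For a sense of scale, the configuration above can be plugged into the derived macros; the snippet below evaluates them on the host. The pass count and scatter block size follow directly from the macros in scatter.glsl, while the histogram block is assumed to follow the same rows-times-workgroup pattern.

#include <stdint.h>
#include <stdio.h>

int main(void)
{
  /* targets/u64/config.h */
  const uint32_t keyval_dwords            = 2;
  const uint32_t scatter_workgroup_log2   = 8;   /* RS_SCATTER_WORKGROUP_SIZE_LOG2   */
  const uint32_t scatter_block_rows       = 14;  /* RS_SCATTER_BLOCK_ROWS            */
  const uint32_t histogram_workgroup_log2 = 8;   /* RS_HISTOGRAM_WORKGROUP_SIZE_LOG2 */
  const uint32_t histogram_block_rows     = 14;  /* RS_HISTOGRAM_BLOCK_ROWS          */

  /* Derived, mirroring RS_KEYVAL_SIZE / RS_BLOCK_KEYVALS in scatter.glsl. */
  const uint32_t keyval_size   = keyval_dwords * 4;             /* bytes            */
  const uint32_t passes        = (keyval_dwords * 32) / 8;      /* 8-bit digits     */
  const uint32_t scatter_block = scatter_block_rows << scatter_workgroup_log2;
  const uint32_t histo_block   = histogram_block_rows << histogram_workgroup_log2;

  printf("passes                = %u\n", passes);               /* 8    */
  printf("scatter block keyvals = %u\n", scatter_block);        /* 3584 */
  printf("scatter block bytes   = %u\n", scatter_block * keyval_size);
  printf("histogram block       = %u\n", histo_block);          /* 3584 */
  return 0;
}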