microsoft/clc: Add API to independently specialize SPIR-V

We need the ability to specialize unlinked SPIR-V, so use SPIR-V tools
to specialize prior to linking.

Acked-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/10322>
This commit is contained in:
Jesse Natalie
2021-04-19 07:14:24 -07:00
committed by Marge Bot
parent c50bbf1f28
commit 068c6b5a37
5 changed files with 95 additions and 2 deletions

View File

@@ -661,6 +661,21 @@ void clc_free_parsed_spirv(struct clc_parsed_spirv *data)
clc_free_kernels_info(data->kernels, data->num_kernels);
}
bool
clc_specialize_spirv(const struct clc_binary *in_spirv,
const struct clc_parsed_spirv *parsed_data,
const struct clc_spirv_specialization_consts *consts,
struct clc_binary *out_spirv)
{
if (!clc_spirv_specialize(in_spirv, parsed_data, consts, out_spirv))
return false;
if (debug_get_option_debug_clc() & CLC_DEBUG_DUMP_SPIRV)
clc_dump_spirv(out_spirv, stdout);
return true;
}
static nir_variable *
add_kernel_inputs_var(struct clc_dxil_object *dxil, nir_shader *nir,
unsigned *cbv_id)

View File

@@ -302,6 +302,13 @@ struct clc_spirv_specialization_consts {
const struct clc_spirv_specialization *specializations;
unsigned num_specializations;
};
bool
clc_specialize_spirv(const struct clc_binary *in_spirv,
const struct clc_parsed_spirv *parsed_data,
const struct clc_spirv_specialization_consts *consts,
struct clc_binary *out_spirv);
bool
clc_spirv_to_dxil(struct clc_libclc *lib,
const struct clc_binary *linked_spirv,

View File

@@ -46,6 +46,7 @@
#include <spirv-tools/libspirv.hpp>
#include <spirv-tools/linker.hpp>
#include <spirv-tools/optimizer.hpp>
#include "util/macros.h"
#include "glsl_types.h"
@@ -58,6 +59,8 @@
#include "opencl-c.h.h"
#include "opencl-c-base.h.h"
constexpr spv_target_env spirv_target = SPV_ENV_UNIVERSAL_1_0;
using ::llvm::Function;
using ::llvm::LLVMContext;
using ::llvm::Module;
@@ -955,7 +958,7 @@ clc_link_spirv_binaries(const struct clc_linker_args *args,
}
SPIRVMessageConsumer msgconsumer(logger);
spvtools::Context context(SPV_ENV_UNIVERSAL_1_0);
spvtools::Context context(spirv_target);
context.SetMessageConsumer(msgconsumer);
spvtools::LinkerOptions options;
options.SetAllowPartialLinkage(args->create_library);
@@ -973,10 +976,71 @@ clc_link_spirv_binaries(const struct clc_linker_args *args,
return 0;
}
int
clc_spirv_specialize(const struct clc_binary *in_spirv,
const struct clc_parsed_spirv *parsed_data,
const struct clc_spirv_specialization_consts *consts,
struct clc_binary *out_spirv)
{
std::unordered_map<uint32_t, std::vector<uint32_t>> spec_const_map;
for (unsigned i = 0; i < consts->num_specializations; ++i) {
unsigned id = consts->specializations[i].id;
auto parsed_spec_const = std::find_if(parsed_data->spec_constants,
parsed_data->spec_constants + parsed_data->num_spec_constants,
[id](const clc_parsed_spec_constant &c) { return c.id == id; });
assert(parsed_spec_const != parsed_data->spec_constants + parsed_data->num_spec_constants);
std::vector<uint32_t> words;
switch (parsed_spec_const->type) {
case CLC_SPEC_CONSTANT_BOOL:
words.push_back(consts->specializations[i].value.b);
break;
case CLC_SPEC_CONSTANT_INT32:
case CLC_SPEC_CONSTANT_UINT32:
case CLC_SPEC_CONSTANT_FLOAT:
words.push_back(consts->specializations[i].value.u32);
break;
case CLC_SPEC_CONSTANT_INT16:
words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i16);
break;
case CLC_SPEC_CONSTANT_INT8:
words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i8);
break;
case CLC_SPEC_CONSTANT_UINT16:
words.push_back((uint32_t)consts->specializations[i].value.u16);
break;
case CLC_SPEC_CONSTANT_UINT8:
words.push_back((uint32_t)consts->specializations[i].value.u8);
break;
case CLC_SPEC_CONSTANT_DOUBLE:
case CLC_SPEC_CONSTANT_INT64:
case CLC_SPEC_CONSTANT_UINT64:
words.resize(2);
memcpy(words.data(), &consts->specializations[i].value.u64, 8);
break;
}
ASSERTED auto ret = spec_const_map.emplace(id, std::move(words));
assert(ret.second);
}
spvtools::Optimizer opt(spirv_target);
opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(std::move(spec_const_map)));
std::vector<uint32_t> result;
if (!opt.Run(static_cast<const uint32_t*>(in_spirv->data), in_spirv->size / 4, &result))
return false;
out_spirv->size = result.size() * 4;
out_spirv->data = malloc(out_spirv->size);
memcpy(out_spirv->data, result.data(), out_spirv->size);
return true;
}
void
clc_dump_spirv(const struct clc_binary *spvbin, FILE *f)
{
spvtools::SpirvTools tools(SPV_ENV_UNIVERSAL_1_0);
spvtools::SpirvTools tools(spirv_target);
const uint32_t *data = static_cast<const uint32_t *>(spvbin->data);
std::vector<uint32_t> bin(data, data + (spvbin->size / 4));
std::string out;

View File

@@ -70,6 +70,12 @@ clc_link_spirv_binaries(const struct clc_linker_args *args,
const struct clc_logger *logger,
struct clc_binary *out_spirv);
int
clc_spirv_specialize(const struct clc_binary *in_spirv,
const struct clc_parsed_spirv *parsed_data,
const struct clc_spirv_specialization_consts *consts,
struct clc_binary *out_spirv);
void
clc_dump_spirv(const struct clc_binary *spvbin, FILE *f);

View File

@@ -12,6 +12,7 @@ EXPORTS
clc_link_spirv
clc_parse_spirv
clc_free_parsed_spirv
clc_specialize_spirv
clc_spirv_to_dxil
clc_free_dxil_object
clc_compiler_get_version