spirv: Use nir_const_value for spec constants
When we originally wrote spirv_to_nir we didn't have a good scalar value union to handily use so we rolled our own thing for spec constants. Now that we have nir_const_value, we can use that and simplify a bunch of the spec constant logic. Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com> Acked-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4675>
This commit is contained in:

committed by
Marge Bot

parent
6211e79ba5
commit
f4addfdde3
@@ -320,7 +320,7 @@ radv_shader_compile_to_nir(struct radv_device *device,
|
|||||||
struct nir_spirv_specialization *spec_entries = NULL;
|
struct nir_spirv_specialization *spec_entries = NULL;
|
||||||
if (spec_info && spec_info->mapEntryCount > 0) {
|
if (spec_info && spec_info->mapEntryCount > 0) {
|
||||||
num_spec_entries = spec_info->mapEntryCount;
|
num_spec_entries = spec_info->mapEntryCount;
|
||||||
spec_entries = malloc(num_spec_entries * sizeof(*spec_entries));
|
spec_entries = calloc(num_spec_entries, sizeof(*spec_entries));
|
||||||
for (uint32_t i = 0; i < num_spec_entries; i++) {
|
for (uint32_t i = 0; i < num_spec_entries; i++) {
|
||||||
VkSpecializationMapEntry entry = spec_info->pMapEntries[i];
|
VkSpecializationMapEntry entry = spec_info->pMapEntries[i];
|
||||||
const void *data = spec_info->pData + entry.offset;
|
const void *data = spec_info->pData + entry.offset;
|
||||||
@@ -329,16 +329,16 @@ radv_shader_compile_to_nir(struct radv_device *device,
|
|||||||
spec_entries[i].id = spec_info->pMapEntries[i].constantID;
|
spec_entries[i].id = spec_info->pMapEntries[i].constantID;
|
||||||
switch (entry.size) {
|
switch (entry.size) {
|
||||||
case 8:
|
case 8:
|
||||||
spec_entries[i].data64 = *(const uint64_t *)data;
|
spec_entries[i].value.u64 = *(const uint64_t *)data;
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
spec_entries[i].data32 = *(const uint32_t *)data;
|
spec_entries[i].value.u32 = *(const uint32_t *)data;
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
spec_entries[i].data32 = *(const uint16_t *)data;
|
spec_entries[i].value.u16 = *(const uint16_t *)data;
|
||||||
break;
|
break;
|
||||||
case 1:
|
case 1:
|
||||||
spec_entries[i].data32 = *(const uint8_t *)data;
|
spec_entries[i].value.u8 = *(const uint8_t *)data;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
assert(!"Invalid spec constant size");
|
assert(!"Invalid spec constant size");
|
||||||
|
@@ -37,10 +37,7 @@ extern "C" {
|
|||||||
|
|
||||||
struct nir_spirv_specialization {
|
struct nir_spirv_specialization {
|
||||||
uint32_t id;
|
uint32_t id;
|
||||||
union {
|
nir_const_value value;
|
||||||
uint32_t data32;
|
|
||||||
uint64_t data64;
|
|
||||||
};
|
|
||||||
bool defined_on_module;
|
bool defined_on_module;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -163,14 +163,6 @@ _vtn_fail(struct vtn_builder *b, const char *file, unsigned line,
|
|||||||
longjmp(b->fail_jump, 1);
|
longjmp(b->fail_jump, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct spec_constant_value {
|
|
||||||
bool is_double;
|
|
||||||
union {
|
|
||||||
uint32_t data32;
|
|
||||||
uint64_t data64;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
static struct vtn_ssa_value *
|
static struct vtn_ssa_value *
|
||||||
vtn_undef_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
|
vtn_undef_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
|
||||||
{
|
{
|
||||||
@@ -1547,41 +1539,15 @@ spec_constant_decoration_cb(struct vtn_builder *b, UNUSED struct vtn_value *val,
|
|||||||
if (dec->decoration != SpvDecorationSpecId)
|
if (dec->decoration != SpvDecorationSpecId)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
struct spec_constant_value *const_value = data;
|
nir_const_value *value = data;
|
||||||
|
|
||||||
for (unsigned i = 0; i < b->num_specializations; i++) {
|
for (unsigned i = 0; i < b->num_specializations; i++) {
|
||||||
if (b->specializations[i].id == dec->operands[0]) {
|
if (b->specializations[i].id == dec->operands[0]) {
|
||||||
if (const_value->is_double)
|
*value = b->specializations[i].value;
|
||||||
const_value->data64 = b->specializations[i].data64;
|
|
||||||
else
|
|
||||||
const_value->data32 = b->specializations[i].data32;
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static uint32_t
|
|
||||||
get_specialization(struct vtn_builder *b, struct vtn_value *val,
|
|
||||||
uint32_t const_value)
|
|
||||||
{
|
|
||||||
struct spec_constant_value data;
|
|
||||||
data.is_double = false;
|
|
||||||
data.data32 = const_value;
|
|
||||||
vtn_foreach_decoration(b, val, spec_constant_decoration_cb, &data);
|
|
||||||
return data.data32;
|
|
||||||
}
|
|
||||||
|
|
||||||
static uint64_t
|
|
||||||
get_specialization64(struct vtn_builder *b, struct vtn_value *val,
|
|
||||||
uint64_t const_value)
|
|
||||||
{
|
|
||||||
struct spec_constant_value data;
|
|
||||||
data.is_double = true;
|
|
||||||
data.data64 = const_value;
|
|
||||||
vtn_foreach_decoration(b, val, spec_constant_decoration_cb, &data);
|
|
||||||
return data.data64;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void
|
static void
|
||||||
handle_workgroup_size_decoration_cb(struct vtn_builder *b,
|
handle_workgroup_size_decoration_cb(struct vtn_builder *b,
|
||||||
struct vtn_value *val,
|
struct vtn_value *val,
|
||||||
@@ -1613,18 +1579,21 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||||||
"Result type of %s must be OpTypeBool",
|
"Result type of %s must be OpTypeBool",
|
||||||
spirv_op_to_string(opcode));
|
spirv_op_to_string(opcode));
|
||||||
|
|
||||||
uint32_t int_val = (opcode == SpvOpConstantTrue ||
|
bool bval = (opcode == SpvOpConstantTrue ||
|
||||||
opcode == SpvOpSpecConstantTrue);
|
opcode == SpvOpSpecConstantTrue);
|
||||||
|
|
||||||
|
nir_const_value u32val = nir_const_value_for_uint(bval, 32);
|
||||||
|
|
||||||
if (opcode == SpvOpSpecConstantTrue ||
|
if (opcode == SpvOpSpecConstantTrue ||
|
||||||
opcode == SpvOpSpecConstantFalse)
|
opcode == SpvOpSpecConstantFalse)
|
||||||
int_val = get_specialization(b, val, int_val);
|
vtn_foreach_decoration(b, val, spec_constant_decoration_cb, &u32val);
|
||||||
|
|
||||||
val->constant->values[0].b = int_val != 0;
|
val->constant->values[0].b = u32val.u32 != 0;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case SpvOpConstant: {
|
case SpvOpConstant:
|
||||||
|
case SpvOpSpecConstant: {
|
||||||
vtn_fail_if(val->type->base_type != vtn_base_type_scalar,
|
vtn_fail_if(val->type->base_type != vtn_base_type_scalar,
|
||||||
"Result type of %s must be a scalar",
|
"Result type of %s must be a scalar",
|
||||||
spirv_op_to_string(opcode));
|
spirv_op_to_string(opcode));
|
||||||
@@ -1645,31 +1614,10 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||||||
default:
|
default:
|
||||||
vtn_fail("Unsupported SpvOpConstant bit size: %u", bit_size);
|
vtn_fail("Unsupported SpvOpConstant bit size: %u", bit_size);
|
||||||
}
|
}
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
case SpvOpSpecConstant: {
|
if (opcode == SpvOpSpecConstant)
|
||||||
vtn_fail_if(val->type->base_type != vtn_base_type_scalar,
|
vtn_foreach_decoration(b, val, spec_constant_decoration_cb,
|
||||||
"Result type of %s must be a scalar",
|
&val->constant->values[0]);
|
||||||
spirv_op_to_string(opcode));
|
|
||||||
int bit_size = glsl_get_bit_size(val->type->type);
|
|
||||||
switch (bit_size) {
|
|
||||||
case 64:
|
|
||||||
val->constant->values[0].u64 =
|
|
||||||
get_specialization64(b, val, vtn_u64_literal(&w[3]));
|
|
||||||
break;
|
|
||||||
case 32:
|
|
||||||
val->constant->values[0].u32 = get_specialization(b, val, w[3]);
|
|
||||||
break;
|
|
||||||
case 16:
|
|
||||||
val->constant->values[0].u16 = get_specialization(b, val, w[3]);
|
|
||||||
break;
|
|
||||||
case 8:
|
|
||||||
val->constant->values[0].u8 = get_specialization(b, val, w[3]);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
vtn_fail("Unsupported SpvOpSpecConstant bit size");
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1719,7 +1667,9 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||||||
}
|
}
|
||||||
|
|
||||||
case SpvOpSpecConstantOp: {
|
case SpvOpSpecConstantOp: {
|
||||||
SpvOp opcode = get_specialization(b, val, w[3]);
|
nir_const_value u32op = nir_const_value_for_uint(w[3], 32);
|
||||||
|
vtn_foreach_decoration(b, val, spec_constant_decoration_cb, &u32op);
|
||||||
|
SpvOp opcode = u32op.u32;
|
||||||
switch (opcode) {
|
switch (opcode) {
|
||||||
case SpvOpVectorShuffle: {
|
case SpvOpVectorShuffle: {
|
||||||
struct vtn_value *v0 = &b->values[w[4]];
|
struct vtn_value *v0 = &b->values[w[4]];
|
||||||
|
@@ -53,7 +53,7 @@ tu_spirv_to_nir(struct ir3_compiler *compiler,
|
|||||||
struct nir_spirv_specialization *spec = NULL;
|
struct nir_spirv_specialization *spec = NULL;
|
||||||
uint32_t num_spec = 0;
|
uint32_t num_spec = 0;
|
||||||
if (spec_info && spec_info->mapEntryCount) {
|
if (spec_info && spec_info->mapEntryCount) {
|
||||||
spec = malloc(sizeof(*spec) * spec_info->mapEntryCount);
|
spec = calloc(spec_info->mapEntryCount, sizeof(*spec));
|
||||||
if (!spec)
|
if (!spec)
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
||||||
@@ -64,16 +64,16 @@ tu_spirv_to_nir(struct ir3_compiler *compiler,
|
|||||||
spec[i].id = entry->constantID;
|
spec[i].id = entry->constantID;
|
||||||
switch (entry->size) {
|
switch (entry->size) {
|
||||||
case 8:
|
case 8:
|
||||||
spec[i].data64 = *(const uint64_t *)data;
|
spec[i].value.u64 = *(const uint64_t *)data;
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
spec[i].data32 = *(const uint32_t *)data;
|
spec[i].value.u32 = *(const uint32_t *)data;
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
spec[i].data32 = *(const uint16_t *)data;
|
spec[i].value.u16 = *(const uint16_t *)data;
|
||||||
break;
|
break;
|
||||||
case 1:
|
case 1:
|
||||||
spec[i].data32 = *(const uint8_t *)data;
|
spec[i].value.u8 = *(const uint8_t *)data;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
assert(!"Invalid spec constant size");
|
assert(!"Invalid spec constant size");
|
||||||
|
@@ -140,7 +140,7 @@ anv_shader_compile_to_nir(struct anv_device *device,
|
|||||||
struct nir_spirv_specialization *spec_entries = NULL;
|
struct nir_spirv_specialization *spec_entries = NULL;
|
||||||
if (spec_info && spec_info->mapEntryCount > 0) {
|
if (spec_info && spec_info->mapEntryCount > 0) {
|
||||||
num_spec_entries = spec_info->mapEntryCount;
|
num_spec_entries = spec_info->mapEntryCount;
|
||||||
spec_entries = malloc(num_spec_entries * sizeof(*spec_entries));
|
spec_entries = calloc(num_spec_entries, sizeof(*spec_entries));
|
||||||
for (uint32_t i = 0; i < num_spec_entries; i++) {
|
for (uint32_t i = 0; i < num_spec_entries; i++) {
|
||||||
VkSpecializationMapEntry entry = spec_info->pMapEntries[i];
|
VkSpecializationMapEntry entry = spec_info->pMapEntries[i];
|
||||||
const void *data = spec_info->pData + entry.offset;
|
const void *data = spec_info->pData + entry.offset;
|
||||||
@@ -149,16 +149,16 @@ anv_shader_compile_to_nir(struct anv_device *device,
|
|||||||
spec_entries[i].id = spec_info->pMapEntries[i].constantID;
|
spec_entries[i].id = spec_info->pMapEntries[i].constantID;
|
||||||
switch (entry.size) {
|
switch (entry.size) {
|
||||||
case 8:
|
case 8:
|
||||||
spec_entries[i].data64 = *(const uint64_t *)data;
|
spec_entries[i].value.u64 = *(const uint64_t *)data;
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
spec_entries[i].data32 = *(const uint32_t *)data;
|
spec_entries[i].value.u32 = *(const uint32_t *)data;
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
spec_entries[i].data32 = *(const uint16_t *)data;
|
spec_entries[i].value.u16 = *(const uint16_t *)data;
|
||||||
break;
|
break;
|
||||||
case 1:
|
case 1:
|
||||||
spec_entries[i].data32 = *(const uint8_t *)data;
|
spec_entries[i].value.u8 = *(const uint8_t *)data;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
assert(!"Invalid spec constant size");
|
assert(!"Invalid spec constant size");
|
||||||
|
@@ -239,7 +239,7 @@ _mesa_spirv_to_nir(struct gl_context *ctx,
|
|||||||
|
|
||||||
for (unsigned i = 0; i < spirv_data->NumSpecializationConstants; ++i) {
|
for (unsigned i = 0; i < spirv_data->NumSpecializationConstants; ++i) {
|
||||||
spec_entries[i].id = spirv_data->SpecializationConstantsIndex[i];
|
spec_entries[i].id = spirv_data->SpecializationConstantsIndex[i];
|
||||||
spec_entries[i].data32 = spirv_data->SpecializationConstantsValue[i];
|
spec_entries[i].value.u32 = spirv_data->SpecializationConstantsValue[i];
|
||||||
spec_entries[i].defined_on_module = false;
|
spec_entries[i].defined_on_module = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,7 +370,7 @@ _mesa_SpecializeShaderARB(GLuint shader,
|
|||||||
|
|
||||||
for (unsigned i = 0; i < numSpecializationConstants; ++i) {
|
for (unsigned i = 0; i < numSpecializationConstants; ++i) {
|
||||||
spec_entries[i].id = pConstantIndex[i];
|
spec_entries[i].id = pConstantIndex[i];
|
||||||
spec_entries[i].data32 = pConstantValue[i];
|
spec_entries[i].value.u32 = pConstantValue[i];
|
||||||
spec_entries[i].defined_on_module = false;
|
spec_entries[i].defined_on_module = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user