spirv: Handle arbitrary bit sizes for deref array indices

We already had code in link_as_ssa to handle bit sizes; we just need to
use it.  While we're at it we clean up link_as_ssa a bit and add an
explicit bit_size parameter in preparation for a day when we have derefs
that aren't 32 bit.

Cc: mesa-stable@lists.freedesktop.org
Reviewed-by: Alejandro Piñeiro <apinheiro@igalia.com>
Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
This commit is contained in:
Jason Ekstrand
2018-12-14 11:06:07 -06:00
committed by Jason Ekstrand
parent bfe31c5e46
commit abfe674c54
2 changed files with 42 additions and 34 deletions

View File

@@ -390,7 +390,7 @@ enum vtn_access_mode {
struct vtn_access_link {
enum vtn_access_mode mode;
uint32_t id;
int64_t id;
};
struct vtn_access_chain {

View File

@@ -65,6 +65,23 @@ vtn_pointer_is_external_block(struct vtn_builder *b,
b->options->lower_workgroup_access_to_offsets);
}
static nir_ssa_def *
vtn_access_link_as_ssa(struct vtn_builder *b, struct vtn_access_link link,
unsigned stride, unsigned bit_size)
{
vtn_assert(stride > 0);
if (link.mode == vtn_access_mode_literal) {
return nir_imm_intN_t(&b->nb, link.id * stride, bit_size);
} else {
nir_ssa_def *ssa = vtn_ssa_value(b, link.id)->def;
if (ssa->bit_size != bit_size)
ssa = nir_i2i(&b->nb, ssa, bit_size);
if (stride != 1)
ssa = nir_imul_imm(&b->nb, ssa, stride);
return ssa;
}
}
/* Dereference the given base pointer by the access chain */
static struct vtn_pointer *
vtn_nir_deref_pointer_dereference(struct vtn_builder *b,
@@ -95,13 +112,8 @@ vtn_nir_deref_pointer_dereference(struct vtn_builder *b,
tail = nir_build_deref_struct(&b->nb, tail, idx);
type = type->members[idx];
} else {
nir_ssa_def *index;
if (deref_chain->link[i].mode == vtn_access_mode_literal) {
index = nir_imm_int(&b->nb, deref_chain->link[i].id);
} else {
vtn_assert(deref_chain->link[i].mode == vtn_access_mode_id);
index = vtn_ssa_value(b, deref_chain->link[i].id)->def;
}
nir_ssa_def *index = vtn_access_link_as_ssa(b, deref_chain->link[i], 1,
tail->dest.ssa.bit_size);
tail = nir_build_deref_array(&b->nb, tail, index);
type = type->array_element;
}
@@ -119,26 +131,6 @@ vtn_nir_deref_pointer_dereference(struct vtn_builder *b,
return ptr;
}
static nir_ssa_def *
vtn_access_link_as_ssa(struct vtn_builder *b, struct vtn_access_link link,
unsigned stride)
{
vtn_assert(stride > 0);
if (link.mode == vtn_access_mode_literal) {
return nir_imm_int(&b->nb, link.id * stride);
} else if (stride == 1) {
nir_ssa_def *ssa = vtn_ssa_value(b, link.id)->def;
if (ssa->bit_size != 32)
ssa = nir_i2i32(&b->nb, ssa);
return ssa;
} else {
nir_ssa_def *src0 = vtn_ssa_value(b, link.id)->def;
if (src0->bit_size != 32)
src0 = nir_i2i32(&b->nb, src0);
return nir_imul_imm(&b->nb, src0, stride);
}
}
static nir_ssa_def *
vtn_variable_resource_index(struct vtn_builder *b, struct vtn_variable *var,
nir_ssa_def *desc_array_index)
@@ -196,7 +188,7 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b,
if (glsl_type_is_array(type->type)) {
if (deref_chain->length >= 1) {
desc_arr_idx =
vtn_access_link_as_ssa(b, deref_chain->link[0], 1);
vtn_access_link_as_ssa(b, deref_chain->link[0], 1, 32);
idx++;
/* This consumes a level of type */
type = type->array_element;
@@ -212,7 +204,7 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b,
} else if (deref_chain->ptr_as_array) {
/* You can't have a zero-length OpPtrAccessChain */
vtn_assert(deref_chain->length >= 1);
desc_arr_idx = vtn_access_link_as_ssa(b, deref_chain->link[0], 1);
desc_arr_idx = vtn_access_link_as_ssa(b, deref_chain->link[0], 1, 32);
} else {
/* We have a regular non-array SSBO. */
desc_arr_idx = NULL;
@@ -244,7 +236,7 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b,
*/
vtn_assert(deref_chain->length >= 1);
nir_ssa_def *offset_index =
vtn_access_link_as_ssa(b, deref_chain->link[0], 1);
vtn_access_link_as_ssa(b, deref_chain->link[0], 1, 32);
idx++;
block_index = vtn_resource_reindex(b, block_index, offset_index);
@@ -298,7 +290,7 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b,
nir_ssa_def *elem_offset =
vtn_access_link_as_ssa(b, deref_chain->link[idx],
base->ptr_type->stride);
base->ptr_type->stride, offset->bit_size);
offset = nir_iadd(&b->nb, offset, elem_offset);
idx++;
}
@@ -319,7 +311,8 @@ vtn_ssa_offset_pointer_dereference(struct vtn_builder *b,
case GLSL_TYPE_BOOL:
case GLSL_TYPE_ARRAY: {
nir_ssa_def *elem_offset =
vtn_access_link_as_ssa(b, deref_chain->link[idx], type->stride);
vtn_access_link_as_ssa(b, deref_chain->link[idx],
type->stride, offset->bit_size);
offset = nir_iadd(&b->nb, offset, elem_offset);
type = type->array_element;
access |= type->access;
@@ -1911,7 +1904,22 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
struct vtn_value *link_val = vtn_untyped_value(b, w[i]);
if (link_val->value_type == vtn_value_type_constant) {
chain->link[idx].mode = vtn_access_mode_literal;
chain->link[idx].id = link_val->constant->values[0].u32[0];
switch (glsl_get_bit_size(link_val->type->type)) {
case 8:
chain->link[idx].id = link_val->constant->values[0].i8[0];
break;
case 16:
chain->link[idx].id = link_val->constant->values[0].i16[0];
break;
case 32:
chain->link[idx].id = link_val->constant->values[0].i32[0];
break;
case 64:
chain->link[idx].id = link_val->constant->values[0].i64[0];
break;
default:
vtn_fail("Invalid bit size");
}
} else {
chain->link[idx].mode = vtn_access_mode_id;
chain->link[idx].id = w[i];