nir/spirv: Rework access chains a bit to allow for literals

This makes them much easier to construct because you can also just specify
a literal number and it doesn't have to be a valid SPIR-V id.
This commit is contained in:
Jason Ekstrand
2016-01-21 10:20:50 -08:00
parent 5d9a6fd526
commit a8af0f536c
2 changed files with 75 additions and 37 deletions

View File

@@ -236,13 +236,23 @@ struct vtn_type {
struct vtn_variable;
enum vtn_access_mode {
vtn_access_mode_id,
vtn_access_mode_literal,
};
struct vtn_access_link {
enum vtn_access_mode mode;
uint32_t id;
};
struct vtn_access_chain {
struct vtn_variable *var;
uint32_t length;
/* Struct elements and array offsets */
uint32_t ids[0];
struct vtn_access_link link[0];
};
enum vtn_variable_mode {

View File

@@ -27,6 +27,39 @@
#include "vtn_private.h"
static struct vtn_access_chain *
vtn_access_chain_extend(struct vtn_builder *b, struct vtn_access_chain *old,
unsigned new_ids)
{
struct vtn_access_chain *chain;
unsigned new_len = old->length + new_ids;
chain = ralloc_size(b, sizeof(*chain) + new_len * sizeof(chain->link[0]));
chain->var = old->var;
chain->length = new_len;
for (unsigned i = 0; i < old->length; i++)
chain->link[i] = old->link[i];
return chain;
}
static nir_ssa_def *
vtn_access_link_as_ssa(struct vtn_builder *b, struct vtn_access_link link,
unsigned stride)
{
assert(stride > 0);
if (link.mode == vtn_access_mode_literal) {
return nir_imm_int(&b->nb, link.id * stride);
} else if (stride == 1) {
return vtn_ssa_value(b, link.id)->def;
} else {
return nir_imul(&b->nb, vtn_ssa_value(b, link.id)->def,
nir_imm_int(&b->nb, stride));
}
}
/* Crawls a chain of array derefs and rewrites the types so that the
* lengths stay the same but the terminal type is the one given by
* tail_type. This is useful for split structures.
@@ -60,7 +93,6 @@ vtn_access_chain_to_deref(struct vtn_builder *b, struct vtn_access_chain *chain)
nir_variable **members = chain->var->members;
for (unsigned i = 0; i < chain->length; i++) {
struct vtn_value *idx_val = vtn_untyped_value(b, chain->ids[i]);
enum glsl_base_type base_type = glsl_get_base_type(deref_type->type);
switch (base_type) {
case GLSL_TYPE_UINT:
@@ -81,15 +113,15 @@ vtn_access_chain_to_deref(struct vtn_builder *b, struct vtn_access_chain *chain)
deref_arr->deref.type = deref_type->type;
if (idx_val->value_type == vtn_value_type_constant) {
if (chain->link[i].mode == vtn_access_mode_literal) {
deref_arr->deref_array_type = nir_deref_array_type_direct;
deref_arr->base_offset = idx_val->constant->value.u[0];
deref_arr->base_offset = chain->link[i].id;
} else {
assert(idx_val->value_type == vtn_value_type_ssa);
assert(glsl_type_is_scalar(idx_val->ssa->type));
assert(chain->link[i].mode == vtn_access_mode_id);
deref_arr->deref_array_type = nir_deref_array_type_indirect;
deref_arr->base_offset = 0;
deref_arr->indirect = nir_src_for_ssa(idx_val->ssa->def);
deref_arr->indirect =
nir_src_for_ssa(vtn_ssa_value(b, chain->link[i].id)->def);
}
tail->child = &deref_arr->deref;
tail = tail->child;
@@ -97,8 +129,8 @@ vtn_access_chain_to_deref(struct vtn_builder *b, struct vtn_access_chain *chain)
}
case GLSL_TYPE_STRUCT: {
assert(idx_val->value_type == vtn_value_type_constant);
unsigned idx = idx_val->constant->value.u[0];
assert(chain->link[i].mode == vtn_access_mode_literal);
unsigned idx = chain->link[i].id;
deref_type = deref_type->members[idx];
if (members) {
/* This is a pre-split structure. */
@@ -265,7 +297,7 @@ get_vulkan_resource_index(struct vtn_builder *b, struct vtn_access_chain *chain,
nir_ssa_def *array_index;
if (glsl_type_is_array(chain->var->type->type)) {
assert(chain->length > 0);
array_index = vtn_ssa_value(b, chain->ids[0])->def;
array_index = vtn_access_link_as_ssa(b, chain->link[0], 1);
*chain_idx = 1;
*type = chain->var->type->array_element;
} else {
@@ -315,9 +347,8 @@ vtn_access_chain_to_offset(struct vtn_builder *b,
case GLSL_TYPE_ARRAY:
offset = nir_iadd(&b->nb, offset,
nir_imul(&b->nb,
vtn_ssa_value(b, chain->ids[idx])->def,
nir_imm_int(&b->nb, type->stride)));
vtn_access_link_as_ssa(b, chain->link[idx],
type->stride));
if (glsl_type_is_vector(type->type)) {
/* This had better be the tail */
@@ -330,10 +361,8 @@ vtn_access_chain_to_offset(struct vtn_builder *b,
break;
case GLSL_TYPE_STRUCT: {
struct vtn_value *member_val =
vtn_value(b, chain->ids[idx], vtn_value_type_constant);
unsigned member = member_val->constant->value.u[0];
assert(chain->link[idx].mode == vtn_access_mode_literal);
unsigned member = chain->link[idx].id;
offset = nir_iadd(&b->nb, offset,
nir_imm_int(&b->nb, type->offsets[member]));
type = type->members[member];
@@ -448,16 +477,15 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
} else if (type->row_major) {
/* Row-major but with an access chiain. */
nir_ssa_def *col_offset =
nir_imul(&b->nb, vtn_ssa_value(b, chain->ids[chain_idx])->def,
nir_imm_int(&b->nb, type->array_element->stride));
vtn_access_link_as_ssa(b, chain->link[chain_idx],
type->array_element->stride);
offset = nir_iadd(&b->nb, offset, col_offset);
if (chain_idx + 1 < chain->length) {
/* Picking off a single element */
nir_ssa_def *row_offset =
nir_imul(&b->nb,
vtn_ssa_value(b, chain->ids[chain_idx + 1])->def,
nir_imm_int(&b->nb, type->stride));
vtn_access_link_as_ssa(b, chain->link[chain_idx + 1],
type->stride);
offset = nir_iadd(&b->nb, offset, row_offset);
_vtn_load_store_tail(b, op, load, index, offset, inout,
glsl_scalar_type(base_type));
@@ -487,8 +515,7 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
} else {
/* Column-major with a deref. Fall through to array case. */
nir_ssa_def *col_offset =
nir_imul(&b->nb, vtn_ssa_value(b, chain->ids[chain_idx])->def,
nir_imm_int(&b->nb, type->stride));
vtn_access_link_as_ssa(b, chain->link[chain_idx], type->stride);
offset = nir_iadd(&b->nb, offset, col_offset);
_vtn_block_load_store(b, op, load, index, offset,
@@ -502,8 +529,7 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
} else {
/* Single component of a vector. Fall through to array case. */
nir_ssa_def *elem_offset =
nir_imul(&b->nb, vtn_ssa_value(b, chain->ids[chain_idx])->def,
nir_imm_int(&b->nb, type->stride));
vtn_access_link_as_ssa(b, chain->link[chain_idx], type->stride);
offset = nir_iadd(&b->nb, offset, elem_offset);
_vtn_block_load_store(b, op, load, index, offset, NULL, 0,
@@ -1158,18 +1184,20 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
base = base_val->access_chain;
}
uint32_t new_len = base->length + count - 4;
chain = ralloc_size(b, sizeof(*chain) + new_len * sizeof(chain->ids[0]));
chain = vtn_access_chain_extend(b, base, count - 4);
*chain = *base;
chain->length = new_len;
unsigned idx = 0;
for (int i = 0; i < base->length; i++)
chain->ids[idx++] = base->ids[i];
for (int i = 4; i < count; i++)
chain->ids[idx++] = w[i];
unsigned idx = base->length;
for (int i = 4; i < count; i++) {
struct vtn_value *link_val = vtn_untyped_value(b, w[i]);
if (link_val->value_type == vtn_value_type_constant) {
chain->link[idx].mode = vtn_access_mode_literal;
chain->link[idx].id = link_val->constant->value.u[0];
} else {
chain->link[idx].mode = vtn_access_mode_id;
chain->link[idx].id = w[i];
}
idx++;
}
if (base_val->value_type == vtn_value_type_sampled_image) {
struct vtn_value *val =