nir/spirv: Rework the way pointers get dereferenced

This has the advantage of moving all of the "extend an access chain"
code into one place.

Reviewed-by: Connor Abbott <cwabbott0@gmail.com>
This commit is contained in:
Jason Ekstrand
2017-06-29 10:33:33 -07:00
committed by Jason Ekstrand
parent 4c21e6b7f8
commit 604eda3712
2 changed files with 94 additions and 70 deletions

View File

@@ -255,8 +255,12 @@ struct vtn_access_link {
struct vtn_access_chain {
uint32_t length;
/* Struct elements and array offsets */
struct vtn_access_link link[0];
/** Struct elements and array offsets.
*
* This is an array of 1 so that it can conveniently be created on the
* stack but the real length is given by the length field.
*/
struct vtn_access_link link[1];
};
enum vtn_variable_mode {

View File

@@ -28,6 +28,20 @@
#include "vtn_private.h"
#include "spirv_info.h"
static struct vtn_access_chain *
vtn_access_chain_create(struct vtn_builder *b, unsigned length)
{
struct vtn_access_chain *chain;
/* Subtract 1 from the length since there's already one built in */
size_t size = sizeof(*chain) +
(MAX2(length, 1) - 1) * sizeof(chain->link[0]);
chain = rzalloc_size(b, size);
chain->length = length;
return chain;
}
static struct vtn_access_chain *
vtn_access_chain_extend(struct vtn_builder *b, struct vtn_access_chain *old,
unsigned new_ids)
@@ -35,11 +49,7 @@ vtn_access_chain_extend(struct vtn_builder *b, struct vtn_access_chain *old,
struct vtn_access_chain *chain;
unsigned old_len = old ? old->length : 0;
unsigned new_len = old_len + new_ids;
/* TODO: don't use rzalloc */
chain = rzalloc_size(b, sizeof(*chain) + new_len * sizeof(chain->link[0]));
chain->length = new_len;
chain = vtn_access_chain_create(b, old_len + new_ids);
for (unsigned i = 0; i < old_len; i++)
chain->link[i] = old->link[i];
@@ -47,6 +57,37 @@ vtn_access_chain_extend(struct vtn_builder *b, struct vtn_access_chain *old,
return chain;
}
/* Dereference the given base pointer by the access chain */
static struct vtn_pointer *
vtn_pointer_dereference(struct vtn_builder *b,
struct vtn_pointer *base,
struct vtn_access_chain *deref_chain)
{
struct vtn_access_chain *chain =
vtn_access_chain_extend(b, base->chain, deref_chain->length);
struct vtn_type *type = base->type;
unsigned start = base->chain ? base->chain->length : 0;
for (unsigned i = 0; i < deref_chain->length; i++) {
chain->link[start + i] = deref_chain->link[i];
if (glsl_type_is_struct(type->type)) {
assert(deref_chain->link[i].mode == vtn_access_mode_literal);
type = type->members[deref_chain->link[i].id];
} else {
type = type->array_element;
}
}
struct vtn_pointer *ptr = rzalloc(b, struct vtn_pointer);
ptr->mode = base->mode;
ptr->type = type;
ptr->var = base->var;
ptr->chain = chain;
return ptr;
}
static nir_ssa_def *
vtn_access_link_as_ssa(struct vtn_builder *b, struct vtn_access_link link,
unsigned stride)
@@ -730,15 +771,16 @@ _vtn_variable_load_store(struct vtn_builder *b, bool load,
(*inout)->elems = rzalloc_array(b, struct vtn_ssa_value *, elems);
}
struct vtn_pointer elem = *ptr;
elem.chain = vtn_access_chain_extend(b, ptr->chain, 1);
unsigned link_idx = ptr->chain ? ptr->chain->length : 0;
elem.chain->link[link_idx].mode = vtn_access_mode_literal;
struct vtn_access_chain chain = {
.length = 1,
.link = {
{ .mode = vtn_access_mode_literal, },
}
};
for (unsigned i = 0; i < elems; i++) {
elem.chain->link[link_idx].id = i;
elem.type = (base_type == GLSL_TYPE_ARRAY) ? ptr->type->array_element :
ptr->type->members[i];
_vtn_variable_load_store(b, load, &elem, &(*inout)->elems[i]);
chain.link[0].id = i;
struct vtn_pointer *elem = vtn_pointer_dereference(b, ptr, &chain);
_vtn_variable_load_store(b, load, elem, &(*inout)->elems[i]);
}
return;
}
@@ -797,24 +839,21 @@ _vtn_variable_copy(struct vtn_builder *b, struct vtn_pointer *dest,
case GLSL_TYPE_ARRAY:
case GLSL_TYPE_STRUCT: {
struct vtn_pointer src_elem = *src, dest_elem = *dest;
src_elem.chain = vtn_access_chain_extend(b, src->chain, 1);
dest_elem.chain = vtn_access_chain_extend(b, dest->chain, 1);
src_elem.chain->link[src_elem.chain->length - 1].mode = vtn_access_mode_literal;
dest_elem.chain->link[dest_elem.chain->length - 1].mode = vtn_access_mode_literal;
struct vtn_access_chain chain = {
.length = 1,
.link = {
{ .mode = vtn_access_mode_literal, },
}
};
unsigned elems = glsl_get_length(src->type->type);
for (unsigned i = 0; i < elems; i++) {
src_elem.chain->link[src_elem.chain->length - 1].id = i;
dest_elem.chain->link[dest_elem.chain->length - 1].id = i;
if (base_type == GLSL_TYPE_STRUCT) {
src_elem.type = src->type->members[i];
dest_elem.type = dest->type->members[i];
} else {
src_elem.type = src->type->array_element;
dest_elem.type = dest->type->array_element;
}
_vtn_variable_copy(b, &dest_elem, &src_elem);
chain.link[0].id = i;
struct vtn_pointer *src_elem =
vtn_pointer_dereference(b, src, &chain);
struct vtn_pointer *dest_elem =
vtn_pointer_dereference(b, dest, &chain);
_vtn_variable_copy(b, dest_elem, src_elem);
}
return;
}
@@ -1548,7 +1587,22 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
case SpvOpAccessChain:
case SpvOpInBoundsAccessChain: {
struct vtn_pointer *base, *ptr;
struct vtn_access_chain *chain = vtn_access_chain_create(b, count - 4);
unsigned idx = 0;
for (int i = 4; i < count; i++) {
struct vtn_value *link_val = vtn_untyped_value(b, w[i]);
if (link_val->value_type == vtn_value_type_constant) {
chain->link[idx].mode = vtn_access_mode_literal;
chain->link[idx].id = link_val->constant->values[0].u32[0];
} else {
chain->link[idx].mode = vtn_access_mode_id;
chain->link[idx].id = w[i];
}
idx++;
}
struct vtn_value *base_val = vtn_untyped_value(b, w[3]);
if (base_val->value_type == vtn_value_type_sampled_image) {
/* This is rather insane. SPIR-V allows you to use OpSampledImage
@@ -1558,51 +1612,17 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
* sampler when crawling the access chain, but it does leave us
* with this rather awkward little special-case.
*/
base = base_val->sampled_image->image;
} else {
assert(base_val->value_type == vtn_value_type_pointer);
base = base_val->pointer;
}
struct vtn_access_chain *chain =
vtn_access_chain_extend(b, base->chain, count - 4);
struct vtn_type *type = base->type;
unsigned idx = base->chain ? base->chain->length : 0;
for (int i = 4; i < count; i++) {
struct vtn_value *link_val = vtn_untyped_value(b, w[i]);
if (link_val->value_type == vtn_value_type_constant) {
chain->link[idx].mode = vtn_access_mode_literal;
chain->link[idx].id = link_val->constant->values[0].u32[0];
} else {
chain->link[idx].mode = vtn_access_mode_id;
chain->link[idx].id = w[i];
}
if (glsl_type_is_struct(type->type)) {
assert(chain->link[idx].mode == vtn_access_mode_literal);
type = type->members[chain->link[idx].id];
} else {
type = type->array_element;
}
idx++;
}
ptr = ralloc(b, struct vtn_pointer);
*ptr = *base;
ptr->chain = chain;
ptr->type = type;
if (base_val->value_type == vtn_value_type_sampled_image) {
struct vtn_value *val =
vtn_push_value(b, w[2], vtn_value_type_sampled_image);
val->sampled_image = ralloc(b, struct vtn_sampled_image);
val->sampled_image->image = ptr;
val->sampled_image->image =
vtn_pointer_dereference(b, base_val->sampled_image->image, chain);
val->sampled_image->sampler = base_val->sampled_image->sampler;
} else {
assert(base_val->value_type == vtn_value_type_pointer);
struct vtn_value *val =
vtn_push_value(b, w[2], vtn_value_type_pointer);
val->pointer = ptr;
val->pointer = vtn_pointer_dereference(b, base_val->pointer, chain);
}
break;
}