matrices

This commit is contained in:
Connor Abbott
2015-06-18 18:52:44 -07:00
parent d0fc04aacf
commit 841aab6f50
2 changed files with 268 additions and 28 deletions

View File

@@ -37,7 +37,7 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
if (entry) if (entry)
return entry->data; return entry->data;
struct vtn_ssa_value *val = ralloc(b, struct vtn_ssa_value); struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
val->type = type; val->type = type;
switch (glsl_get_base_type(type)) { switch (glsl_get_base_type(type)) {
@@ -63,7 +63,7 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
val->elems = ralloc_array(b, struct vtn_ssa_value *, columns); val->elems = ralloc_array(b, struct vtn_ssa_value *, columns);
for (unsigned i = 0; i < columns; i++) { for (unsigned i = 0; i < columns; i++) {
struct vtn_ssa_value *col_val = ralloc(b, struct vtn_ssa_value); struct vtn_ssa_value *col_val = rzalloc(b, struct vtn_ssa_value);
col_val->type = glsl_get_column_type(val->type); col_val->type = glsl_get_column_type(val->type);
nir_load_const_instr *load = nir_load_const_instr *load =
nir_load_const_instr_create(b->shader, rows); nir_load_const_instr_create(b->shader, rows);
@@ -516,6 +516,7 @@ var_decoration_cb(struct vtn_builder *b, struct vtn_value *val,
case SpvDecorationFPFastMathMode: case SpvDecorationFPFastMathMode:
case SpvDecorationLinkageAttributes: case SpvDecorationLinkageAttributes:
case SpvDecorationSpecId: case SpvDecorationSpecId:
break;
default: default:
unreachable("Unhandled variable decoration"); unreachable("Unhandled variable decoration");
} }
@@ -525,7 +526,7 @@ static struct vtn_ssa_value *
_vtn_variable_load(struct vtn_builder *b, _vtn_variable_load(struct vtn_builder *b,
nir_deref_var *src_deref, nir_deref *src_deref_tail) nir_deref_var *src_deref, nir_deref *src_deref_tail)
{ {
struct vtn_ssa_value *val = ralloc(b, struct vtn_ssa_value); struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
val->type = src_deref_tail->type; val->type = src_deref_tail->type;
/* The deref tail may contain a deref to select a component of a vector (in /* The deref tail may contain a deref to select a component of a vector (in
@@ -1010,11 +1011,264 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
nir_builder_instr_insert(&b->nb, &instr->instr); nir_builder_instr_insert(&b->nb, &instr->instr);
} }
/* Recursively allocate an empty vtn_ssa_value tree that mirrors the
 * structure of the given GLSL type.  Composite levels (matrices, arrays,
 * structs) get an elems[] array of per-child values; leaves (vectors and
 * scalars) get only their type set — the caller is responsible for filling
 * in each leaf's nir_ssa_def.  All allocations are ralloc'd off b.
 */
static struct vtn_ssa_value *
vtn_create_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
{
   struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
   val->type = type;

   if (!glsl_type_is_vector_or_scalar(type)) {
      unsigned elems = glsl_get_length(type);
      val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
      for (unsigned i = 0; i < elems; i++) {
         const struct glsl_type *child_type;

         switch (glsl_get_base_type(type)) {
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_BOOL:
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_DOUBLE:
            /* A non-vector/scalar value with a scalar base type must be a
             * matrix; its children are the column vectors.
             */
            child_type = glsl_get_column_type(type);
            break;
         case GLSL_TYPE_ARRAY:
            child_type = glsl_get_array_element(type);
            break;
         case GLSL_TYPE_STRUCT:
            child_type = glsl_get_struct_field(type, i);
            break;
         default:
            unreachable("unknown base type"); /* fixed typo: "unkown" */
         }

         val->elems[i] = vtn_create_ssa_value(b, child_type);
      }
   }

   return val;
}
/* Allocate (but do not insert) an ALU instruction that assembles a vector
 * with the requested number of components.  A single component uses fmov,
 * since NIR has no vec1 opcode; the caller fills in the sources and inserts
 * the instruction.
 */
static nir_alu_instr *
create_vec(void *mem_ctx, unsigned num_components)
{
   nir_op op;
   if (num_components == 1)
      op = nir_op_fmov;
   else if (num_components == 2)
      op = nir_op_vec2;
   else if (num_components == 3)
      op = nir_op_vec3;
   else if (num_components == 4)
      op = nir_op_vec4;
   else
      unreachable("bad vector size");

   nir_alu_instr *vec = nir_alu_instr_create(mem_ctx, op);
   nir_ssa_dest_init(&vec->instr, &vec->dest.dest, num_components, NULL);

   return vec;
}
static struct vtn_ssa_value *
vtn_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
{
if (src->transposed)
return src->transposed;
struct vtn_ssa_value *dest =
vtn_create_ssa_value(b, glsl_transposed_type(src->type));
for (unsigned i = 0; i < glsl_get_matrix_columns(dest->type); i++) {
nir_alu_instr *vec = create_vec(b, glsl_get_matrix_columns(src->type));
if (glsl_type_is_vector_or_scalar(src->type)) {
vec->src[0].src = nir_src_for_ssa(src->def);
vec->src[0].swizzle[0] = i;
} else {
for (unsigned j = 0; j < glsl_get_matrix_columns(src->type); j++) {
vec->src[j].src = nir_src_for_ssa(src->elems[j]->def);
vec->src[j].swizzle[0] = i;
}
}
nir_builder_instr_insert(&b->nb, &vec->instr);
dest->elems[i]->def = &vec->dest.dest.ssa;
}
dest->transposed = src;
return dest;
}
/*
 * Column vectors in SPIR-V normally correspond to a single NIR SSA
 * definition, but the matrix-multiply routine below wants every operand in
 * matrix form.  So we treat a vector as a matrix with a single column:
 * vtn_wrap_matrix performs that wrapping, and vtn_unwrap_matrix undoes it
 * before the result is handed back.
 */
static struct vtn_ssa_value *
vtn_wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   /* NULL passes through; real matrices need no wrapping. */
   if (val == NULL || glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = val->type;
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}
/* Inverse of vtn_wrap_matrix: a genuine matrix passes through unchanged,
 * while a wrapped vector yields its single column.
 */
static struct vtn_ssa_value *
vtn_unwrap_matrix(struct vtn_ssa_value *val)
{
   return glsl_type_is_matrix(val->type) ? val : val->elems[0];
}
/* Emit NIR for a matrix multiply of _src0 by _src1.  Vector operands are
 * accepted: they are wrapped as one-column matrices internally and the
 * result is unwrapped again before returning.  Cached transposes
 * (->transposed) of the operands are exploited when they let us use dot
 * products or reassociate the multiply.
 */
static struct vtn_ssa_value *
vtn_matrix_multiply(struct vtn_builder *b,
                    struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{
   struct vtn_ssa_value *src0 = vtn_wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = vtn_wrap_matrix(b, _src1);
   /* These are non-NULL only if a transpose was previously computed and
    * cached (vtn_wrap_matrix passes NULL through).
    */
   struct vtn_ssa_value *src0_transpose = vtn_wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = vtn_wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   /* Result shape: src0_rows x src1_columns. */
   struct vtn_ssa_value *dest =
      vtn_create_ssa_value(b, glsl_matrix_type(glsl_get_base_type(src0->type),
                                               src0_rows, src1_columns));

   dest = vtn_wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */
      for (unsigned i = 0; i < src1_columns; i++) {
         nir_alu_instr *vec = create_vec(b, src0_rows);
         for (unsigned j = 0; j < src0_rows; j++) {
            vec->src[j].src =
               nir_src_for_ssa(nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                        src1->elems[i]->def));
         }

         nir_builder_instr_insert(&b->nb, &vec->instr);
         dest->elems[i]->def = &vec->dest.dest.ssa;
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0, since
       * the general case only uses individual components of src1 so the
       * optimizer should chew through the transpose we emitted for src1.
       */
      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         /* Start the accumulator with the j == 0 term, then fadd the rest. */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[0]->def,
                     vtn_vector_extract(b, src1->elems[i]->def, 0));
         for (unsigned j = 1; j < src0_columns; j++) {
            dest->elems[i]->def =
               nir_fadd(&b->nb, dest->elems[i]->def,
                        nir_fmul(&b->nb, src0->elems[j]->def,
                                 vtn_vector_extract(b,
                                                    src1->elems[i]->def, j)));
         }
      }
   }

   dest = vtn_unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_transpose(b, dest);

   return dest;
}
/* Multiply each column of mat by the given scalar, producing a new matrix
 * value.  Float matrices use fmul; any other base type uses imul.
 */
static struct vtn_ssa_value *
vtn_mat_times_scalar(struct vtn_builder *b,
                     struct vtn_ssa_value *mat,
                     nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   bool is_float = glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT;
   unsigned columns = glsl_get_matrix_columns(mat->type);

   for (unsigned col = 0; col < columns; col++) {
      dest->elems[col]->def = is_float
         ? nir_fmul(&b->nb, mat->elems[col]->def, scalar)
         : nir_imul(&b->nb, mat->elems[col]->def, scalar);
   }

   return dest;
}
static void static void
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode, vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count) const uint32_t *w, unsigned count)
{ {
unreachable("Matrix math not handled"); struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
switch (opcode) {
case SpvOpTranspose: {
struct vtn_ssa_value *src = vtn_ssa_value(b, w[3]);
val->ssa = vtn_transpose(b, src);
break;
}
case SpvOpOuterProduct: {
struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);
val->ssa = vtn_matrix_multiply(b, src0, vtn_transpose(b, src1));
break;
}
case SpvOpMatrixTimesScalar: {
struct vtn_ssa_value *mat = vtn_ssa_value(b, w[3]);
struct vtn_ssa_value *scalar = vtn_ssa_value(b, w[4]);
if (mat->transposed) {
val->ssa = vtn_transpose(b, vtn_mat_times_scalar(b, mat->transposed,
scalar->def));
} else {
val->ssa = vtn_mat_times_scalar(b, mat, scalar->def);
}
break;
}
case SpvOpVectorTimesMatrix:
case SpvOpMatrixTimesVector:
case SpvOpMatrixTimesMatrix: {
struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);
val->ssa = vtn_matrix_multiply(b, src0, src1);
break;
}
default: unreachable("unknown matrix opcode");
}
} }
static void static void
@@ -1197,29 +1451,10 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
/* Extract component `index` of src as a scalar SSA def, via a one-component
 * swizzle.
 *
 * NOTE(review): this span was corrupted by a side-by-side diff render fusing
 * the old nir_alu_src/nir_fmov_alu implementation with the new nir_swizzle
 * one; reconstructed here from the new side.
 */
static nir_ssa_def *
vtn_vector_extract(struct vtn_builder *b, nir_ssa_def *src, unsigned index)
{
   unsigned swiz[4] = { index };
   return nir_swizzle(&b->nb, src, swiz, 1, true);
}
static nir_alu_instr *
create_vec(void *mem_ctx, unsigned num_components)
{
nir_op op;
switch (num_components) {
case 1: op = nir_op_fmov; break;
case 2: op = nir_op_vec2; break;
case 3: op = nir_op_vec3; break;
case 4: op = nir_op_vec4; break;
default: unreachable("bad vector size");
}
nir_alu_instr *vec = nir_alu_instr_create(mem_ctx, op);
nir_ssa_dest_init(&vec->instr, &vec->dest.dest, num_components, NULL);
return vec;
}
static nir_ssa_def * static nir_ssa_def *
vtn_vector_insert(struct vtn_builder *b, nir_ssa_def *src, nir_ssa_def *insert, vtn_vector_insert(struct vtn_builder *b, nir_ssa_def *src, nir_ssa_def *insert,
@@ -1320,7 +1555,7 @@ vtn_vector_construct(struct vtn_builder *b, unsigned num_components,
static struct vtn_ssa_value * static struct vtn_ssa_value *
vtn_composite_copy(void *mem_ctx, struct vtn_ssa_value *src) vtn_composite_copy(void *mem_ctx, struct vtn_ssa_value *src)
{ {
struct vtn_ssa_value *dest = ralloc(mem_ctx, struct vtn_ssa_value); struct vtn_ssa_value *dest = rzalloc(mem_ctx, struct vtn_ssa_value);
dest->type = src->type; dest->type = src->type;
if (glsl_type_is_vector_or_scalar(src->type)) { if (glsl_type_is_vector_or_scalar(src->type)) {
@@ -1376,7 +1611,7 @@ vtn_composite_extract(struct vtn_builder *b, struct vtn_ssa_value *src,
* vector to extract. * vector to extract.
*/ */
struct vtn_ssa_value *ret = ralloc(b, struct vtn_ssa_value); struct vtn_ssa_value *ret = rzalloc(b, struct vtn_ssa_value);
ret->type = glsl_scalar_type(glsl_get_base_type(cur->type)); ret->type = glsl_scalar_type(glsl_get_base_type(cur->type));
ret->def = vtn_vector_extract(b, cur->def, indices[i]); ret->def = vtn_vector_extract(b, cur->def, indices[i]);
return ret; return ret;
@@ -1413,7 +1648,7 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
break; break;
case SpvOpCompositeConstruct: { case SpvOpCompositeConstruct: {
val->ssa = ralloc(b, struct vtn_ssa_value); val->ssa = rzalloc(b, struct vtn_ssa_value);
unsigned elems = count - 3; unsigned elems = count - 3;
if (glsl_type_is_vector_or_scalar(val->type)) { if (glsl_type_is_vector_or_scalar(val->type)) {
nir_ssa_def *srcs[4]; nir_ssa_def *srcs[4];

View File

@@ -71,6 +71,11 @@ struct vtn_ssa_value {
struct vtn_ssa_value **elems; struct vtn_ssa_value **elems;
}; };
/* For matrices, a transposed version of the value, or NULL if it hasn't
* been computed
*/
struct vtn_ssa_value *transposed;
const struct glsl_type *type; const struct glsl_type *type;
}; };