matrices matrices matrices
This commit is contained in:
@@ -37,7 +37,7 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
|
|||||||
if (entry)
|
if (entry)
|
||||||
return entry->data;
|
return entry->data;
|
||||||
|
|
||||||
struct vtn_ssa_value *val = ralloc(b, struct vtn_ssa_value);
|
struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
|
||||||
val->type = type;
|
val->type = type;
|
||||||
|
|
||||||
switch (glsl_get_base_type(type)) {
|
switch (glsl_get_base_type(type)) {
|
||||||
@@ -63,7 +63,7 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
|
|||||||
val->elems = ralloc_array(b, struct vtn_ssa_value *, columns);
|
val->elems = ralloc_array(b, struct vtn_ssa_value *, columns);
|
||||||
|
|
||||||
for (unsigned i = 0; i < columns; i++) {
|
for (unsigned i = 0; i < columns; i++) {
|
||||||
struct vtn_ssa_value *col_val = ralloc(b, struct vtn_ssa_value);
|
struct vtn_ssa_value *col_val = rzalloc(b, struct vtn_ssa_value);
|
||||||
col_val->type = glsl_get_column_type(val->type);
|
col_val->type = glsl_get_column_type(val->type);
|
||||||
nir_load_const_instr *load =
|
nir_load_const_instr *load =
|
||||||
nir_load_const_instr_create(b->shader, rows);
|
nir_load_const_instr_create(b->shader, rows);
|
||||||
@@ -516,6 +516,7 @@ var_decoration_cb(struct vtn_builder *b, struct vtn_value *val,
|
|||||||
case SpvDecorationFPFastMathMode:
|
case SpvDecorationFPFastMathMode:
|
||||||
case SpvDecorationLinkageAttributes:
|
case SpvDecorationLinkageAttributes:
|
||||||
case SpvDecorationSpecId:
|
case SpvDecorationSpecId:
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
unreachable("Unhandled variable decoration");
|
unreachable("Unhandled variable decoration");
|
||||||
}
|
}
|
||||||
@@ -525,7 +526,7 @@ static struct vtn_ssa_value *
|
|||||||
_vtn_variable_load(struct vtn_builder *b,
|
_vtn_variable_load(struct vtn_builder *b,
|
||||||
nir_deref_var *src_deref, nir_deref *src_deref_tail)
|
nir_deref_var *src_deref, nir_deref *src_deref_tail)
|
||||||
{
|
{
|
||||||
struct vtn_ssa_value *val = ralloc(b, struct vtn_ssa_value);
|
struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
|
||||||
val->type = src_deref_tail->type;
|
val->type = src_deref_tail->type;
|
||||||
|
|
||||||
/* The deref tail may contain a deref to select a component of a vector (in
|
/* The deref tail may contain a deref to select a component of a vector (in
|
||||||
@@ -1010,11 +1011,264 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
|
|||||||
nir_builder_instr_insert(&b->nb, &instr->instr);
|
nir_builder_instr_insert(&b->nb, &instr->instr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static struct vtn_ssa_value *
|
||||||
|
vtn_create_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
|
||||||
|
{
|
||||||
|
struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
|
||||||
|
val->type = type;
|
||||||
|
|
||||||
|
if (!glsl_type_is_vector_or_scalar(type)) {
|
||||||
|
unsigned elems = glsl_get_length(type);
|
||||||
|
val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
|
||||||
|
for (unsigned i = 0; i < elems; i++) {
|
||||||
|
const struct glsl_type *child_type;
|
||||||
|
|
||||||
|
switch (glsl_get_base_type(type)) {
|
||||||
|
case GLSL_TYPE_INT:
|
||||||
|
case GLSL_TYPE_UINT:
|
||||||
|
case GLSL_TYPE_BOOL:
|
||||||
|
case GLSL_TYPE_FLOAT:
|
||||||
|
case GLSL_TYPE_DOUBLE:
|
||||||
|
child_type = glsl_get_column_type(type);
|
||||||
|
break;
|
||||||
|
case GLSL_TYPE_ARRAY:
|
||||||
|
child_type = glsl_get_array_element(type);
|
||||||
|
break;
|
||||||
|
case GLSL_TYPE_STRUCT:
|
||||||
|
child_type = glsl_get_struct_field(type, i);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
unreachable("unkown base type");
|
||||||
|
}
|
||||||
|
|
||||||
|
val->elems[i] = vtn_create_ssa_value(b, child_type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
static nir_alu_instr *
|
||||||
|
create_vec(void *mem_ctx, unsigned num_components)
|
||||||
|
{
|
||||||
|
nir_op op;
|
||||||
|
switch (num_components) {
|
||||||
|
case 1: op = nir_op_fmov; break;
|
||||||
|
case 2: op = nir_op_vec2; break;
|
||||||
|
case 3: op = nir_op_vec3; break;
|
||||||
|
case 4: op = nir_op_vec4; break;
|
||||||
|
default: unreachable("bad vector size");
|
||||||
|
}
|
||||||
|
|
||||||
|
nir_alu_instr *vec = nir_alu_instr_create(mem_ctx, op);
|
||||||
|
nir_ssa_dest_init(&vec->instr, &vec->dest.dest, num_components, NULL);
|
||||||
|
|
||||||
|
return vec;
|
||||||
|
}
|
||||||
|
|
||||||
|
static struct vtn_ssa_value *
|
||||||
|
vtn_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
|
||||||
|
{
|
||||||
|
if (src->transposed)
|
||||||
|
return src->transposed;
|
||||||
|
|
||||||
|
struct vtn_ssa_value *dest =
|
||||||
|
vtn_create_ssa_value(b, glsl_transposed_type(src->type));
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < glsl_get_matrix_columns(dest->type); i++) {
|
||||||
|
nir_alu_instr *vec = create_vec(b, glsl_get_matrix_columns(src->type));
|
||||||
|
if (glsl_type_is_vector_or_scalar(src->type)) {
|
||||||
|
vec->src[0].src = nir_src_for_ssa(src->def);
|
||||||
|
vec->src[0].swizzle[0] = i;
|
||||||
|
} else {
|
||||||
|
for (unsigned j = 0; j < glsl_get_matrix_columns(src->type); j++) {
|
||||||
|
vec->src[j].src = nir_src_for_ssa(src->elems[j]->def);
|
||||||
|
vec->src[j].swizzle[0] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nir_builder_instr_insert(&b->nb, &vec->instr);
|
||||||
|
dest->elems[i]->def = &vec->dest.dest.ssa;
|
||||||
|
}
|
||||||
|
|
||||||
|
dest->transposed = src;
|
||||||
|
|
||||||
|
return dest;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Normally, column vectors in SPIR-V correspond to a single NIR SSA
|
||||||
|
* definition. But for matrix multiplies, we want to do one routine for
|
||||||
|
* multiplying a matrix by a matrix and then pretend that vectors are matrices
|
||||||
|
* with one column. So we "wrap" these things, and unwrap the result before we
|
||||||
|
* send it off.
|
||||||
|
*/
|
||||||
|
|
||||||
|
static struct vtn_ssa_value *
|
||||||
|
vtn_wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
|
||||||
|
{
|
||||||
|
if (val == NULL)
|
||||||
|
return NULL;
|
||||||
|
|
||||||
|
if (glsl_type_is_matrix(val->type))
|
||||||
|
return val;
|
||||||
|
|
||||||
|
struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
|
||||||
|
dest->type = val->type;
|
||||||
|
dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
|
||||||
|
dest->elems[0] = val;
|
||||||
|
|
||||||
|
return dest;
|
||||||
|
}
|
||||||
|
|
||||||
|
static struct vtn_ssa_value *
|
||||||
|
vtn_unwrap_matrix(struct vtn_ssa_value *val)
|
||||||
|
{
|
||||||
|
if (glsl_type_is_matrix(val->type))
|
||||||
|
return val;
|
||||||
|
|
||||||
|
return val->elems[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
static struct vtn_ssa_value *
|
||||||
|
vtn_matrix_multiply(struct vtn_builder *b,
|
||||||
|
struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
|
||||||
|
{
|
||||||
|
|
||||||
|
struct vtn_ssa_value *src0 = vtn_wrap_matrix(b, _src0);
|
||||||
|
struct vtn_ssa_value *src1 = vtn_wrap_matrix(b, _src1);
|
||||||
|
struct vtn_ssa_value *src0_transpose = vtn_wrap_matrix(b, _src0->transposed);
|
||||||
|
struct vtn_ssa_value *src1_transpose = vtn_wrap_matrix(b, _src1->transposed);
|
||||||
|
|
||||||
|
unsigned src0_rows = glsl_get_vector_elements(src0->type);
|
||||||
|
unsigned src0_columns = glsl_get_matrix_columns(src0->type);
|
||||||
|
unsigned src1_columns = glsl_get_matrix_columns(src1->type);
|
||||||
|
|
||||||
|
struct vtn_ssa_value *dest =
|
||||||
|
vtn_create_ssa_value(b, glsl_matrix_type(glsl_get_base_type(src0->type),
|
||||||
|
src0_rows, src1_columns));
|
||||||
|
|
||||||
|
dest = vtn_wrap_matrix(b, dest);
|
||||||
|
|
||||||
|
bool transpose_result = false;
|
||||||
|
if (src0_transpose && src1_transpose) {
|
||||||
|
/* transpose(A) * transpose(B) = transpose(B * A) */
|
||||||
|
src1 = src0_transpose;
|
||||||
|
src0 = src1_transpose;
|
||||||
|
src0_transpose = NULL;
|
||||||
|
src1_transpose = NULL;
|
||||||
|
transpose_result = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (src0_transpose && !src1_transpose &&
|
||||||
|
glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
|
||||||
|
/* We already have the rows of src0 and the columns of src1 available,
|
||||||
|
* so we can just take the dot product of each row with each column to
|
||||||
|
* get the result.
|
||||||
|
*/
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < src1_columns; i++) {
|
||||||
|
nir_alu_instr *vec = create_vec(b, src0_rows);
|
||||||
|
for (unsigned j = 0; j < src0_rows; j++) {
|
||||||
|
vec->src[j].src =
|
||||||
|
nir_src_for_ssa(nir_fdot(&b->nb, src0_transpose->elems[j]->def,
|
||||||
|
src1->elems[i]->def));
|
||||||
|
}
|
||||||
|
|
||||||
|
nir_builder_instr_insert(&b->nb, &vec->instr);
|
||||||
|
dest->elems[i]->def = &vec->dest.dest.ssa;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
/* We don't handle the case where src1 is transposed but not src0, since
|
||||||
|
* the general case only uses individual components of src1 so the
|
||||||
|
* optimizer should chew through the transpose we emitted for src1.
|
||||||
|
*/
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < src1_columns; i++) {
|
||||||
|
/* dest[i] = sum(src0[j] * src1[i][j] for all j) */
|
||||||
|
dest->elems[i]->def =
|
||||||
|
nir_fmul(&b->nb, src0->elems[0]->def,
|
||||||
|
vtn_vector_extract(b, src1->elems[i]->def, 0));
|
||||||
|
for (unsigned j = 1; j < src0_columns; j++) {
|
||||||
|
dest->elems[i]->def =
|
||||||
|
nir_fadd(&b->nb, dest->elems[i]->def,
|
||||||
|
nir_fmul(&b->nb, src0->elems[j]->def,
|
||||||
|
vtn_vector_extract(b,
|
||||||
|
src1->elems[i]->def, j)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dest = vtn_unwrap_matrix(dest);
|
||||||
|
|
||||||
|
if (transpose_result)
|
||||||
|
dest = vtn_transpose(b, dest);
|
||||||
|
|
||||||
|
return dest;
|
||||||
|
}
|
||||||
|
|
||||||
|
static struct vtn_ssa_value *
|
||||||
|
vtn_mat_times_scalar(struct vtn_builder *b,
|
||||||
|
struct vtn_ssa_value *mat,
|
||||||
|
nir_ssa_def *scalar)
|
||||||
|
{
|
||||||
|
struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
|
||||||
|
for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
|
||||||
|
if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
|
||||||
|
dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
|
||||||
|
else
|
||||||
|
dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
|
||||||
|
}
|
||||||
|
|
||||||
|
return dest;
|
||||||
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
|
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
|
||||||
const uint32_t *w, unsigned count)
|
const uint32_t *w, unsigned count)
|
||||||
{
|
{
|
||||||
unreachable("Matrix math not handled");
|
struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
|
||||||
|
val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
|
||||||
|
|
||||||
|
switch (opcode) {
|
||||||
|
case SpvOpTranspose: {
|
||||||
|
struct vtn_ssa_value *src = vtn_ssa_value(b, w[3]);
|
||||||
|
val->ssa = vtn_transpose(b, src);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case SpvOpOuterProduct: {
|
||||||
|
struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
|
||||||
|
struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);
|
||||||
|
|
||||||
|
val->ssa = vtn_matrix_multiply(b, src0, vtn_transpose(b, src1));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case SpvOpMatrixTimesScalar: {
|
||||||
|
struct vtn_ssa_value *mat = vtn_ssa_value(b, w[3]);
|
||||||
|
struct vtn_ssa_value *scalar = vtn_ssa_value(b, w[4]);
|
||||||
|
|
||||||
|
if (mat->transposed) {
|
||||||
|
val->ssa = vtn_transpose(b, vtn_mat_times_scalar(b, mat->transposed,
|
||||||
|
scalar->def));
|
||||||
|
} else {
|
||||||
|
val->ssa = vtn_mat_times_scalar(b, mat, scalar->def);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case SpvOpVectorTimesMatrix:
|
||||||
|
case SpvOpMatrixTimesVector:
|
||||||
|
case SpvOpMatrixTimesMatrix: {
|
||||||
|
struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
|
||||||
|
struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);
|
||||||
|
|
||||||
|
val->ssa = vtn_matrix_multiply(b, src0, src1);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
default: unreachable("unknown matrix opcode");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
@@ -1197,29 +1451,10 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
|
|||||||
static nir_ssa_def *
|
static nir_ssa_def *
|
||||||
vtn_vector_extract(struct vtn_builder *b, nir_ssa_def *src, unsigned index)
|
vtn_vector_extract(struct vtn_builder *b, nir_ssa_def *src, unsigned index)
|
||||||
{
|
{
|
||||||
nir_alu_src alu_src;
|
unsigned swiz[4] = { index };
|
||||||
alu_src.src = nir_src_for_ssa(src);
|
return nir_swizzle(&b->nb, src, swiz, 1, true);
|
||||||
alu_src.swizzle[0] = index;
|
|
||||||
return nir_fmov_alu(&b->nb, alu_src, 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static nir_alu_instr *
|
|
||||||
create_vec(void *mem_ctx, unsigned num_components)
|
|
||||||
{
|
|
||||||
nir_op op;
|
|
||||||
switch (num_components) {
|
|
||||||
case 1: op = nir_op_fmov; break;
|
|
||||||
case 2: op = nir_op_vec2; break;
|
|
||||||
case 3: op = nir_op_vec3; break;
|
|
||||||
case 4: op = nir_op_vec4; break;
|
|
||||||
default: unreachable("bad vector size");
|
|
||||||
}
|
|
||||||
|
|
||||||
nir_alu_instr *vec = nir_alu_instr_create(mem_ctx, op);
|
|
||||||
nir_ssa_dest_init(&vec->instr, &vec->dest.dest, num_components, NULL);
|
|
||||||
|
|
||||||
return vec;
|
|
||||||
}
|
|
||||||
|
|
||||||
static nir_ssa_def *
|
static nir_ssa_def *
|
||||||
vtn_vector_insert(struct vtn_builder *b, nir_ssa_def *src, nir_ssa_def *insert,
|
vtn_vector_insert(struct vtn_builder *b, nir_ssa_def *src, nir_ssa_def *insert,
|
||||||
@@ -1320,7 +1555,7 @@ vtn_vector_construct(struct vtn_builder *b, unsigned num_components,
|
|||||||
static struct vtn_ssa_value *
|
static struct vtn_ssa_value *
|
||||||
vtn_composite_copy(void *mem_ctx, struct vtn_ssa_value *src)
|
vtn_composite_copy(void *mem_ctx, struct vtn_ssa_value *src)
|
||||||
{
|
{
|
||||||
struct vtn_ssa_value *dest = ralloc(mem_ctx, struct vtn_ssa_value);
|
struct vtn_ssa_value *dest = rzalloc(mem_ctx, struct vtn_ssa_value);
|
||||||
dest->type = src->type;
|
dest->type = src->type;
|
||||||
|
|
||||||
if (glsl_type_is_vector_or_scalar(src->type)) {
|
if (glsl_type_is_vector_or_scalar(src->type)) {
|
||||||
@@ -1376,7 +1611,7 @@ vtn_composite_extract(struct vtn_builder *b, struct vtn_ssa_value *src,
|
|||||||
* vector to extract.
|
* vector to extract.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
struct vtn_ssa_value *ret = ralloc(b, struct vtn_ssa_value);
|
struct vtn_ssa_value *ret = rzalloc(b, struct vtn_ssa_value);
|
||||||
ret->type = glsl_scalar_type(glsl_get_base_type(cur->type));
|
ret->type = glsl_scalar_type(glsl_get_base_type(cur->type));
|
||||||
ret->def = vtn_vector_extract(b, cur->def, indices[i]);
|
ret->def = vtn_vector_extract(b, cur->def, indices[i]);
|
||||||
return ret;
|
return ret;
|
||||||
@@ -1413,7 +1648,7 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
|
|||||||
break;
|
break;
|
||||||
|
|
||||||
case SpvOpCompositeConstruct: {
|
case SpvOpCompositeConstruct: {
|
||||||
val->ssa = ralloc(b, struct vtn_ssa_value);
|
val->ssa = rzalloc(b, struct vtn_ssa_value);
|
||||||
unsigned elems = count - 3;
|
unsigned elems = count - 3;
|
||||||
if (glsl_type_is_vector_or_scalar(val->type)) {
|
if (glsl_type_is_vector_or_scalar(val->type)) {
|
||||||
nir_ssa_def *srcs[4];
|
nir_ssa_def *srcs[4];
|
||||||
|
@@ -71,6 +71,11 @@ struct vtn_ssa_value {
|
|||||||
struct vtn_ssa_value **elems;
|
struct vtn_ssa_value **elems;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* For matrices, a transposed version of the value, or NULL if it hasn't
|
||||||
|
* been computed
|
||||||
|
*/
|
||||||
|
struct vtn_ssa_value *transposed;
|
||||||
|
|
||||||
const struct glsl_type *type;
|
const struct glsl_type *type;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user