matrices matrices matrices
@@ -37,7 +37,7 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
   if (entry)
      return entry->data;

-   struct vtn_ssa_value *val = ralloc(b, struct vtn_ssa_value);
+   struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
   val->type = type;

   switch (glsl_get_base_type(type)) {
@@ -63,7 +63,7 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
      val->elems = ralloc_array(b, struct vtn_ssa_value *, columns);

      for (unsigned i = 0; i < columns; i++) {
-         struct vtn_ssa_value *col_val = ralloc(b, struct vtn_ssa_value);
+         struct vtn_ssa_value *col_val = rzalloc(b, struct vtn_ssa_value);
         col_val->type = glsl_get_column_type(val->type);
         nir_load_const_instr *load =
            nir_load_const_instr_create(b->shader, rows);
@@ -516,6 +516,7 @@ var_decoration_cb(struct vtn_builder *b, struct vtn_value *val,
   case SpvDecorationFPFastMathMode:
   case SpvDecorationLinkageAttributes:
   case SpvDecorationSpecId:
      break;
   default:
      unreachable("Unhandled variable decoration");
   }
@@ -525,7 +526,7 @@ static struct vtn_ssa_value *
_vtn_variable_load(struct vtn_builder *b,
                   nir_deref_var *src_deref, nir_deref *src_deref_tail)
{
-   struct vtn_ssa_value *val = ralloc(b, struct vtn_ssa_value);
+   struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
   val->type = src_deref_tail->type;

   /* The deref tail may contain a deref to select a component of a vector (in
@@ -1010,11 +1011,264 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
   nir_builder_instr_insert(&b->nb, &instr->instr);
}

static struct vtn_ssa_value *
vtn_create_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
{
   struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
   val->type = type;

   if (!glsl_type_is_vector_or_scalar(type)) {
      unsigned elems = glsl_get_length(type);
      val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
      for (unsigned i = 0; i < elems; i++) {
         const struct glsl_type *child_type;

         switch (glsl_get_base_type(type)) {
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_BOOL:
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_DOUBLE:
            child_type = glsl_get_column_type(type);
            break;
         case GLSL_TYPE_ARRAY:
            child_type = glsl_get_array_element(type);
            break;
         case GLSL_TYPE_STRUCT:
            child_type = glsl_get_struct_field(type, i);
            break;
         default:
            unreachable("unknown base type");
         }

         val->elems[i] = vtn_create_ssa_value(b, child_type);
      }
   }

   return val;
}
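A brief usage sketch (not part of the patch; the variable names are hypothetical): for a two-column float matrix, vtn_create_ssa_value() recurses once per column and leaves each per-column def NULL for the caller to fill in, which is exactly what vtn_transpose() below relies on.

   const struct glsl_type *mat2x3 =
      glsl_matrix_type(GLSL_TYPE_FLOAT, 3, 2);           /* 3 rows, 2 columns */
   struct vtn_ssa_value *m = vtn_create_ssa_value(b, mat2x3);
   /* m->elems[0]->type and m->elems[1]->type are the vec3 column type;
    * each m->elems[i]->def is still NULL until a column is written.  */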

static nir_alu_instr *
create_vec(void *mem_ctx, unsigned num_components)
{
   nir_op op;
   switch (num_components) {
   case 1: op = nir_op_fmov; break;
   case 2: op = nir_op_vec2; break;
   case 3: op = nir_op_vec3; break;
   case 4: op = nir_op_vec4; break;
   default: unreachable("bad vector size");
   }

   nir_alu_instr *vec = nir_alu_instr_create(mem_ctx, op);
   nir_ssa_dest_init(&vec->instr, &vec->dest.dest, num_components, NULL);

   return vec;
}

static struct vtn_ssa_value *
vtn_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
{
   if (src->transposed)
      return src->transposed;

   struct vtn_ssa_value *dest =
      vtn_create_ssa_value(b, glsl_transposed_type(src->type));

   for (unsigned i = 0; i < glsl_get_matrix_columns(dest->type); i++) {
      nir_alu_instr *vec = create_vec(b, glsl_get_matrix_columns(src->type));
      if (glsl_type_is_vector_or_scalar(src->type)) {
         vec->src[0].src = nir_src_for_ssa(src->def);
         vec->src[0].swizzle[0] = i;
      } else {
         for (unsigned j = 0; j < glsl_get_matrix_columns(src->type); j++) {
            vec->src[j].src = nir_src_for_ssa(src->elems[j]->def);
            vec->src[j].swizzle[0] = i;
         }
      }
      nir_builder_instr_insert(&b->nb, &vec->instr);
      dest->elems[i]->def = &vec->dest.dest.ssa;
   }

   dest->transposed = src;

   return dest;
}
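A small usage sketch (hypothetical value m, not part of the patch): each result column is assembled from one swizzled component of every source column, and the source is remembered in dest->transposed, so transposing a freshly transposed value short-circuits back to the original instead of emitting a second set of vecN instructions.

   struct vtn_ssa_value *t = vtn_transpose(b, m);  /* one vecN per row of m  */
   assert(t->transposed == m);
   assert(vtn_transpose(b, t) == m);               /* no extra instructions  */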

/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want to do one routine for
 * multiplying a matrix by a matrix and then pretend that vectors are matrices
 * with one column. So we "wrap" these things, and unwrap the result before we
 * send it off.
 */

static struct vtn_ssa_value *
vtn_wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = val->type;
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

static struct vtn_ssa_value *
vtn_unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}
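A minimal sketch of the wrapping trick (hypothetical vector value v, not part of the patch): a plain vector becomes a one-column "matrix" whose only column is the vector itself, and unwrapping hands the vector back, so vtn_matrix_multiply() below can index ->elems[i] uniformly.

   struct vtn_ssa_value *w = vtn_wrap_matrix(b, v);  /* v is a vecN value        */
   assert(w->elems[0] == v);                         /* single column, unchanged */
   assert(vtn_unwrap_matrix(w) == v);                /* round-trips back to v    */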

static struct vtn_ssa_value *
vtn_matrix_multiply(struct vtn_builder *b,
                    struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{

   struct vtn_ssa_value *src0 = vtn_wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = vtn_wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = vtn_wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = vtn_wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   struct vtn_ssa_value *dest =
      vtn_create_ssa_value(b, glsl_matrix_type(glsl_get_base_type(src0->type),
                                               src0_rows, src1_columns));

   dest = vtn_wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         nir_alu_instr *vec = create_vec(b, src0_rows);
         for (unsigned j = 0; j < src0_rows; j++) {
            vec->src[j].src =
               nir_src_for_ssa(nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                        src1->elems[i]->def));
         }

         nir_builder_instr_insert(&b->nb, &vec->instr);
         dest->elems[i]->def = &vec->dest.dest.ssa;
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0, since
       * the general case only uses individual components of src1 so the
       * optimizer should chew through the transpose we emitted for src1.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[0]->def,
                     vtn_vector_extract(b, src1->elems[i]->def, 0));
         for (unsigned j = 1; j < src0_columns; j++) {
            dest->elems[i]->def =
               nir_fadd(&b->nb, dest->elems[i]->def,
                        nir_fmul(&b->nb, src0->elems[j]->def,
                                 vtn_vector_extract(b,
                                                    src1->elems[i]->def, j)));
         }
      }
   }

   dest = vtn_unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_transpose(b, dest);

   return dest;
}
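A plain-C sketch of the arithmetic the general path emits as NIR (illustrative only; the helper name is hypothetical, matrices are stored column-major, and dest must be zero-initialized): column i of the result is a linear combination of src0's columns weighted by the scalars in column i of src1, which is what the "dest[i] = sum(src0[j] * src1[i][j])" comment above abbreviates.

   static void
   mat_mul_ref(float dest[4][4], const float src0[4][4], const float src1[4][4],
               unsigned src0_rows, unsigned src0_columns, unsigned src1_columns)
   {
      for (unsigned i = 0; i < src1_columns; i++)      /* result column   */
         for (unsigned j = 0; j < src0_columns; j++)   /* term in the sum */
            for (unsigned r = 0; r < src0_rows; r++)   /* vector lanes    */
               dest[i][r] += src0[j][r] * src1[i][j];
   }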

static struct vtn_ssa_value *
vtn_mat_times_scalar(struct vtn_builder *b,
                     struct vtn_ssa_value *mat,
                     nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}

static void
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      const uint32_t *w, unsigned count)
{
   unreachable("Matrix math not handled");
   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
   val->type = vtn_value(b, w[1], vtn_value_type_type)->type;

   switch (opcode) {
   case SpvOpTranspose: {
      struct vtn_ssa_value *src = vtn_ssa_value(b, w[3]);
      val->ssa = vtn_transpose(b, src);
      break;
   }

   case SpvOpOuterProduct: {
      struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
      struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);

      val->ssa = vtn_matrix_multiply(b, src0, vtn_transpose(b, src1));
      break;
   }

   case SpvOpMatrixTimesScalar: {
      struct vtn_ssa_value *mat = vtn_ssa_value(b, w[3]);
      struct vtn_ssa_value *scalar = vtn_ssa_value(b, w[4]);

      if (mat->transposed) {
         val->ssa = vtn_transpose(b, vtn_mat_times_scalar(b, mat->transposed,
                                                          scalar->def));
      } else {
         val->ssa = vtn_mat_times_scalar(b, mat, scalar->def);
      }
      break;
   }

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix: {
      struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
      struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);

      val->ssa = vtn_matrix_multiply(b, src0, src1);
      break;
   }

   default: unreachable("unknown matrix opcode");
   }
}
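A plain-C reference for the result the OpOuterProduct case above is meant to build (illustrative only; the helper is hypothetical and not part of the patch): conceptually, src0 acts as an n x 1 column and the transposed src1 as a 1 x m row, so the generic multiply fills in dest[i][r] = src0[r] * src1[i], one result column at a time.

   static void
   outer_product_ref(float dest[4][4], const float a[4], const float b[4],
                     unsigned n, unsigned m)
   {
      for (unsigned i = 0; i < m; i++)      /* result column */
         for (unsigned r = 0; r < n; r++)   /* result row    */
            dest[i][r] = a[r] * b[i];
   }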

static void
@@ -1197,29 +1451,10 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
static nir_ssa_def *
vtn_vector_extract(struct vtn_builder *b, nir_ssa_def *src, unsigned index)
{
-   nir_alu_src alu_src;
-   alu_src.src = nir_src_for_ssa(src);
-   alu_src.swizzle[0] = index;
-   return nir_fmov_alu(&b->nb, alu_src, 1);
+   unsigned swiz[4] = { index };
+   return nir_swizzle(&b->nb, src, swiz, 1, true);
}

-static nir_alu_instr *
-create_vec(void *mem_ctx, unsigned num_components)
-{
-   nir_op op;
-   switch (num_components) {
-   case 1: op = nir_op_fmov; break;
-   case 2: op = nir_op_vec2; break;
-   case 3: op = nir_op_vec3; break;
-   case 4: op = nir_op_vec4; break;
-   default: unreachable("bad vector size");
-   }
-
-   nir_alu_instr *vec = nir_alu_instr_create(mem_ctx, op);
-   nir_ssa_dest_init(&vec->instr, &vec->dest.dest, num_components, NULL);
-
-   return vec;
-}
-
static nir_ssa_def *
vtn_vector_insert(struct vtn_builder *b, nir_ssa_def *src, nir_ssa_def *insert,
@@ -1320,7 +1555,7 @@ vtn_vector_construct(struct vtn_builder *b, unsigned num_components,
static struct vtn_ssa_value *
vtn_composite_copy(void *mem_ctx, struct vtn_ssa_value *src)
{
-   struct vtn_ssa_value *dest = ralloc(mem_ctx, struct vtn_ssa_value);
+   struct vtn_ssa_value *dest = rzalloc(mem_ctx, struct vtn_ssa_value);
   dest->type = src->type;

   if (glsl_type_is_vector_or_scalar(src->type)) {
@@ -1376,7 +1611,7 @@ vtn_composite_extract(struct vtn_builder *b, struct vtn_ssa_value *src,
             * vector to extract.
             */

-            struct vtn_ssa_value *ret = ralloc(b, struct vtn_ssa_value);
+            struct vtn_ssa_value *ret = rzalloc(b, struct vtn_ssa_value);
            ret->type = glsl_scalar_type(glsl_get_base_type(cur->type));
            ret->def = vtn_vector_extract(b, cur->def, indices[i]);
            return ret;
@@ -1413,7 +1648,7 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
      break;

   case SpvOpCompositeConstruct: {
-      val->ssa = ralloc(b, struct vtn_ssa_value);
+      val->ssa = rzalloc(b, struct vtn_ssa_value);
      unsigned elems = count - 3;
      if (glsl_type_is_vector_or_scalar(val->type)) {
         nir_ssa_def *srcs[4];
@@ -71,6 +71,11 @@ struct vtn_ssa_value {
      struct vtn_ssa_value **elems;
   };

   /* For matrices, a transposed version of the value, or NULL if it hasn't
    * been computed
    */
   struct vtn_ssa_value *transposed;

   const struct glsl_type *type;
};