compiler/types: Add support for Cooperative Matrix types
Reviewed-by: Jesse Natalie <jenatali@microsoft.com> Reviewed-by: Ian Romanick <ian.d.romanick@intel.com> Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23825>
This commit is contained in:
@@ -1191,6 +1191,9 @@ do_comparison(void *mem_ctx, int operation, ir_rvalue *op0, ir_rvalue *op1)
|
||||
* ignores the sampler present in the type.
|
||||
*/
|
||||
break;
|
||||
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
unreachable("unsupported base type cooperative matrix");
|
||||
}
|
||||
|
||||
if (cmp == NULL)
|
||||
|
@@ -169,6 +169,8 @@ copy_constant_to_storage(union gl_constant_value *storage,
|
||||
*/
|
||||
assert(!"Should not get here.");
|
||||
break;
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
unreachable("unsupported base type cooperative matrix");
|
||||
}
|
||||
i += dmul;
|
||||
}
|
||||
|
@@ -370,6 +370,9 @@ ir_constant::clone(void *mem_ctx, struct hash_table *ht) const
|
||||
case GLSL_TYPE_INTERFACE:
|
||||
assert(!"Should not get here.");
|
||||
break;
|
||||
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
unreachable("unsupported base type cooperative matrix");
|
||||
}
|
||||
|
||||
return NULL;
|
||||
|
@@ -51,6 +51,7 @@ static struct {
|
||||
|
||||
hash_table *explicit_matrix_types;
|
||||
hash_table *array_types;
|
||||
hash_table *cmat_types;
|
||||
hash_table *struct_types;
|
||||
hash_table *interface_types;
|
||||
hash_table *subroutine_types;
|
||||
@@ -391,6 +392,7 @@ const glsl_type *glsl_type::get_bare_type() const
|
||||
return get_array_instance(this->fields.array->get_bare_type(),
|
||||
this->length);
|
||||
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
case GLSL_TYPE_SAMPLER:
|
||||
case GLSL_TYPE_TEXTURE:
|
||||
case GLSL_TYPE_IMAGE:
|
||||
@@ -527,6 +529,19 @@ make_array_type(linear_ctx *lin_ctx, const glsl_type *element_type, unsigned len
|
||||
return t;
|
||||
}
|
||||
|
||||
static const char *
|
||||
glsl_cmat_use_to_string(enum glsl_cmat_use use)
|
||||
{
|
||||
switch (use) {
|
||||
case GLSL_CMAT_USE_NONE: return "NONE";
|
||||
case GLSL_CMAT_USE_A: return "A";
|
||||
case GLSL_CMAT_USE_B: return "B";
|
||||
case GLSL_CMAT_USE_ACCUMULATOR: return "ACCUMULATOR";
|
||||
default:
|
||||
unreachable("invalid cooperative matrix use");
|
||||
}
|
||||
};
|
||||
|
||||
const glsl_type *
|
||||
glsl_type::vec(unsigned components, const glsl_type *const ts[])
|
||||
{
|
||||
@@ -1250,6 +1265,68 @@ glsl_type::get_array_instance(const glsl_type *element,
|
||||
return t;
|
||||
}
|
||||
|
||||
static const struct glsl_type *
|
||||
make_cmat_type(linear_ctx *lin_ctx, const struct glsl_cmat_description desc)
|
||||
{
|
||||
assert(lin_ctx != NULL);
|
||||
|
||||
struct glsl_type *t = linear_zalloc(lin_ctx, struct glsl_type);
|
||||
t->base_type = GLSL_TYPE_COOPERATIVE_MATRIX;
|
||||
t->sampled_type = GLSL_TYPE_VOID;
|
||||
t->vector_elements = 1;
|
||||
t->cmat_desc = desc;
|
||||
|
||||
const struct glsl_type *element_type = glsl_type::get_instance(desc.element_type, 1, 1);
|
||||
t->name_id = (uintptr_t ) linear_asprintf(lin_ctx, "coopmat<%s, %s, %u, %u, %s>",
|
||||
glsl_get_type_name(element_type),
|
||||
mesa_scope_name((mesa_scope)desc.scope),
|
||||
desc.rows, desc.cols,
|
||||
glsl_cmat_use_to_string((enum glsl_cmat_use)desc.use));
|
||||
|
||||
return t;
|
||||
}
|
||||
|
||||
const glsl_type *
|
||||
glsl_type::get_cmat_instance(const struct glsl_cmat_description desc)
|
||||
{
|
||||
STATIC_ASSERT(sizeof(struct glsl_cmat_description) == 4);
|
||||
|
||||
const uint32_t key = desc.element_type | desc.scope << 5 |
|
||||
desc.rows << 8 | desc.cols << 16 |
|
||||
desc.use << 24;
|
||||
const uint32_t key_hash = _mesa_hash_uint(&key);
|
||||
|
||||
simple_mtx_lock(&glsl_type_cache_mutex);
|
||||
assert(glsl_type_cache.users > 0);
|
||||
void *mem_ctx = glsl_type_cache.mem_ctx;
|
||||
|
||||
if (glsl_type_cache.cmat_types == NULL) {
|
||||
glsl_type_cache.cmat_types =
|
||||
_mesa_hash_table_create_u32_keys(mem_ctx);
|
||||
}
|
||||
hash_table *cmat_types = glsl_type_cache.cmat_types;
|
||||
|
||||
const struct hash_entry *entry = _mesa_hash_table_search_pre_hashed(
|
||||
cmat_types, key_hash, (void *) (uintptr_t) key);
|
||||
if (entry == NULL) {
|
||||
const struct glsl_type *t = make_cmat_type(glsl_type_cache.lin_ctx, desc);
|
||||
entry = _mesa_hash_table_insert_pre_hashed(cmat_types, key_hash,
|
||||
(void *) (uintptr_t) key, (void *) t);
|
||||
}
|
||||
|
||||
const struct glsl_type *t = (const struct glsl_type *)entry->data;
|
||||
simple_mtx_unlock(&glsl_type_cache_mutex);
|
||||
|
||||
assert(t->base_type == GLSL_TYPE_COOPERATIVE_MATRIX);
|
||||
assert(t->cmat_desc.element_type == desc.element_type);
|
||||
assert(t->cmat_desc.scope == desc.scope);
|
||||
assert(t->cmat_desc.rows == desc.rows);
|
||||
assert(t->cmat_desc.cols == desc.cols);
|
||||
assert(t->cmat_desc.use == desc.use);
|
||||
|
||||
return t;
|
||||
}
|
||||
|
||||
bool
|
||||
glsl_type::compare_no_precision(const glsl_type *b) const
|
||||
{
|
||||
@@ -1679,6 +1756,7 @@ glsl_type::component_slots() const
|
||||
case GLSL_TYPE_SUBROUTINE:
|
||||
return 1;
|
||||
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
case GLSL_TYPE_ATOMIC_UINT:
|
||||
case GLSL_TYPE_VOID:
|
||||
case GLSL_TYPE_ERROR:
|
||||
@@ -1745,6 +1823,7 @@ glsl_type::component_slots_aligned(unsigned offset) const
|
||||
case GLSL_TYPE_SUBROUTINE:
|
||||
return 1;
|
||||
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
case GLSL_TYPE_ATOMIC_UINT:
|
||||
case GLSL_TYPE_VOID:
|
||||
case GLSL_TYPE_ERROR:
|
||||
@@ -2599,6 +2678,10 @@ glsl_type::get_explicit_type_for_size_align(glsl_type_size_align_func type_info,
|
||||
type_info(this, size, alignment);
|
||||
assert(*alignment > 0);
|
||||
return this;
|
||||
} else if (this->is_cmat()) {
|
||||
*size = 0;
|
||||
*alignment = 0;
|
||||
return this;
|
||||
} else if (this->is_scalar()) {
|
||||
type_info(this, size, alignment);
|
||||
assert(*size == explicit_type_scalar_byte_size(this));
|
||||
@@ -2822,6 +2905,7 @@ glsl_type::count_vec4_slots(bool is_gl_vertex_input, bool is_bindless) const
|
||||
case GLSL_TYPE_SUBROUTINE:
|
||||
return 1;
|
||||
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
case GLSL_TYPE_ATOMIC_UINT:
|
||||
case GLSL_TYPE_VOID:
|
||||
case GLSL_TYPE_ERROR:
|
||||
@@ -2925,6 +3009,7 @@ union packed_type {
|
||||
unsigned length:13;
|
||||
unsigned explicit_stride:14;
|
||||
} array;
|
||||
glsl_cmat_description cmat_desc;
|
||||
struct {
|
||||
unsigned base_type:5;
|
||||
unsigned interface_packing_or_packed:2;
|
||||
@@ -3039,6 +3124,10 @@ encode_type_to_blob(struct blob *blob, const glsl_type *type)
|
||||
blob_write_uint32(blob, type->explicit_stride);
|
||||
encode_type_to_blob(blob, type->fields.array);
|
||||
return;
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
encoded.cmat_desc = type->cmat_desc;
|
||||
blob_write_uint32(blob, encoded.u32);
|
||||
return;
|
||||
case GLSL_TYPE_STRUCT:
|
||||
case GLSL_TYPE_INTERFACE:
|
||||
encoded.strct.length = MIN2(type->length, 0xfffff);
|
||||
@@ -3145,6 +3234,9 @@ decode_type_from_blob(struct blob_reader *blob)
|
||||
return glsl_type::get_array_instance(decode_type_from_blob(blob),
|
||||
length, explicit_stride);
|
||||
}
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX: {
|
||||
return glsl_type::get_cmat_instance(encoded.cmat_desc);
|
||||
}
|
||||
case GLSL_TYPE_STRUCT:
|
||||
case GLSL_TYPE_INTERFACE: {
|
||||
char *name = blob_read_string(blob);
|
||||
|
@@ -76,6 +76,7 @@ enum glsl_base_type {
|
||||
GLSL_TYPE_UINT64,
|
||||
GLSL_TYPE_INT64,
|
||||
GLSL_TYPE_BOOL,
|
||||
GLSL_TYPE_COOPERATIVE_MATRIX,
|
||||
GLSL_TYPE_SAMPLER,
|
||||
GLSL_TYPE_TEXTURE,
|
||||
GLSL_TYPE_IMAGE,
|
||||
@@ -167,6 +168,7 @@ glsl_base_type_get_bit_size(const enum glsl_base_type base_type)
|
||||
case GLSL_TYPE_UINT:
|
||||
case GLSL_TYPE_FLOAT: /* TODO handle mediump */
|
||||
case GLSL_TYPE_SUBROUTINE:
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
return 32;
|
||||
|
||||
case GLSL_TYPE_FLOAT16:
|
||||
@@ -279,6 +281,24 @@ enum {
|
||||
GLSL_PRECISION_LOW
|
||||
};
|
||||
|
||||
enum glsl_cmat_use {
|
||||
GLSL_CMAT_USE_NONE = 0,
|
||||
GLSL_CMAT_USE_A,
|
||||
GLSL_CMAT_USE_B,
|
||||
GLSL_CMAT_USE_ACCUMULATOR,
|
||||
};
|
||||
|
||||
struct glsl_cmat_description {
|
||||
/* MSVC can't merge bitfields of different types and also sign extend enums,
|
||||
* so use uint8_t for those cases.
|
||||
*/
|
||||
uint8_t element_type:5; /* enum glsl_base_type */
|
||||
uint8_t scope:3; /* mesa_scope */
|
||||
uint8_t rows;
|
||||
uint8_t cols;
|
||||
uint8_t use; /* enum glsl_cmat_use */
|
||||
};
|
||||
|
||||
const char *glsl_get_type_name(const struct glsl_type *type);
|
||||
|
||||
struct glsl_type {
|
||||
@@ -297,6 +317,8 @@ struct glsl_type {
|
||||
unsigned interface_packing:2;
|
||||
unsigned interface_row_major:1;
|
||||
|
||||
struct glsl_cmat_description cmat_desc;
|
||||
|
||||
/**
|
||||
* For \c GLSL_TYPE_STRUCT this specifies if the struct is packed or not.
|
||||
*
|
||||
@@ -456,6 +478,11 @@ struct glsl_type {
|
||||
unsigned array_size,
|
||||
unsigned explicit_stride = 0);
|
||||
|
||||
/**
|
||||
* Get the instance of a cooperative matrix type
|
||||
*/
|
||||
static const glsl_type *get_cmat_instance(const struct glsl_cmat_description desc);
|
||||
|
||||
/**
|
||||
* Get the instance of a record type
|
||||
*/
|
||||
@@ -931,6 +958,11 @@ struct glsl_type {
|
||||
return is_array() && fields.array->is_array();
|
||||
}
|
||||
|
||||
bool is_cmat() const
|
||||
{
|
||||
return base_type == GLSL_TYPE_COOPERATIVE_MATRIX;
|
||||
}
|
||||
|
||||
/**
|
||||
* Query whether or not a type is a record
|
||||
*/
|
||||
|
@@ -2755,6 +2755,7 @@ nir_get_nir_type_for_glsl_base_type(enum glsl_base_type base_type)
|
||||
case GLSL_TYPE_DOUBLE: return nir_type_float64;
|
||||
/* clang-format on */
|
||||
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
case GLSL_TYPE_SAMPLER:
|
||||
case GLSL_TYPE_TEXTURE:
|
||||
case GLSL_TYPE_IMAGE:
|
||||
|
@@ -335,6 +335,12 @@ glsl_type_is_array_or_matrix(const struct glsl_type *type)
|
||||
return type->is_array() || type->is_matrix();
|
||||
}
|
||||
|
||||
bool
|
||||
glsl_type_is_cmat(const struct glsl_type *type)
|
||||
{
|
||||
return type->is_cmat();
|
||||
}
|
||||
|
||||
bool
|
||||
glsl_type_is_struct(const struct glsl_type *type)
|
||||
{
|
||||
@@ -642,6 +648,12 @@ glsl_array_type(const glsl_type *element, unsigned array_size,
|
||||
return glsl_type::get_array_instance(element, array_size, explicit_stride);
|
||||
}
|
||||
|
||||
const glsl_type *
|
||||
glsl_cmat_type(const glsl_cmat_description *desc)
|
||||
{
|
||||
return glsl_type::get_cmat_instance(*desc);
|
||||
}
|
||||
|
||||
const glsl_type *
|
||||
glsl_replace_vector_type(const glsl_type *t, unsigned components)
|
||||
{
|
||||
@@ -857,6 +869,7 @@ glsl_get_natural_size_align_bytes(const struct glsl_type *type,
|
||||
|
||||
case GLSL_TYPE_ATOMIC_UINT:
|
||||
case GLSL_TYPE_SUBROUTINE:
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
case GLSL_TYPE_VOID:
|
||||
case GLSL_TYPE_ERROR:
|
||||
unreachable("type does not have a natural size");
|
||||
@@ -910,6 +923,7 @@ glsl_get_vec4_size_align_bytes(const struct glsl_type *type,
|
||||
case GLSL_TYPE_IMAGE:
|
||||
case GLSL_TYPE_ATOMIC_UINT:
|
||||
case GLSL_TYPE_SUBROUTINE:
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
case GLSL_TYPE_VOID:
|
||||
case GLSL_TYPE_ERROR:
|
||||
unreachable("type does not make sense for glsl_get_vec4_size_align_bytes()");
|
||||
@@ -1102,3 +1116,17 @@ glsl_type_replace_vec3_with_vec4(const struct glsl_type *type)
|
||||
{
|
||||
return type->replace_vec3_with_vec4();
|
||||
}
|
||||
|
||||
const struct glsl_type *
|
||||
glsl_get_cmat_element(const struct glsl_type *type)
|
||||
{
|
||||
assert(type->base_type == GLSL_TYPE_COOPERATIVE_MATRIX);
|
||||
return glsl_type::get_instance(type->cmat_desc.element_type, 1, 1);
|
||||
}
|
||||
|
||||
const struct glsl_cmat_description *
|
||||
glsl_get_cmat_description(const struct glsl_type *type)
|
||||
{
|
||||
assert(type->base_type == GLSL_TYPE_COOPERATIVE_MATRIX);
|
||||
return &type->cmat_desc;
|
||||
}
|
||||
|
@@ -140,6 +140,7 @@ bool glsl_type_is_array(const struct glsl_type *type);
|
||||
bool glsl_type_is_unsized_array(const struct glsl_type *type);
|
||||
bool glsl_type_is_array_of_arrays(const struct glsl_type *type);
|
||||
bool glsl_type_is_array_or_matrix(const struct glsl_type *type);
|
||||
bool glsl_type_is_cmat(const struct glsl_type *type);
|
||||
bool glsl_type_is_struct(const struct glsl_type *type);
|
||||
bool glsl_type_is_interface(const struct glsl_type *type);
|
||||
bool glsl_type_is_struct_or_ifc(const struct glsl_type *type);
|
||||
@@ -201,6 +202,8 @@ const struct glsl_type *glsl_array_type(const struct glsl_type *element,
|
||||
unsigned array_size,
|
||||
unsigned explicit_stride);
|
||||
|
||||
const struct glsl_type *glsl_cmat_type(const struct glsl_cmat_description *desc);
|
||||
|
||||
const struct glsl_type *glsl_struct_type(const struct glsl_struct_field *fields,
|
||||
unsigned num_fields, const char *name,
|
||||
bool packed);
|
||||
@@ -254,6 +257,9 @@ int glsl_get_field_index(const struct glsl_type *type, const char *name);
|
||||
|
||||
bool glsl_type_is_leaf(const struct glsl_type *type);
|
||||
|
||||
const struct glsl_type *glsl_get_cmat_element(const struct glsl_type *type);
|
||||
const struct glsl_cmat_description *glsl_get_cmat_description(const struct glsl_type *type);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@@ -74,6 +74,7 @@ brw_type_for_base_type(const struct glsl_type *type)
|
||||
return BRW_REGISTER_TYPE_Q;
|
||||
case GLSL_TYPE_VOID:
|
||||
case GLSL_TYPE_ERROR:
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
unreachable("not reached");
|
||||
}
|
||||
|
||||
|
@@ -622,6 +622,7 @@ type_size_xvec4(const struct glsl_type *type, bool as_vec4, bool bindless)
|
||||
return bindless ? 1 : DIV_ROUND_UP(BRW_IMAGE_PARAM_SIZE, 4);
|
||||
case GLSL_TYPE_VOID:
|
||||
case GLSL_TYPE_ERROR:
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
unreachable("not reached");
|
||||
}
|
||||
|
||||
|
@@ -1011,6 +1011,7 @@ associate_uniform_storage(struct gl_context *ctx,
|
||||
case GLSL_TYPE_STRUCT:
|
||||
case GLSL_TYPE_ERROR:
|
||||
case GLSL_TYPE_INTERFACE:
|
||||
case GLSL_TYPE_COOPERATIVE_MATRIX:
|
||||
assert(!"Should not get here.");
|
||||
break;
|
||||
}
|
||||
|
Reference in New Issue
Block a user