nir: Add new variable mode for task/mesh payload.

Task shader outputs work differently than other shaders, so they
need special consideration. Essentially, they have two kinds of
outputs:

1. Number of mesh shader workgroups to launch.
Will still be represented by a shader output.

2. Optional payload whose maximum size is at least 16K bytes.
These payload variables behave similarly to shared memory, but
the spec doesn't actually define them as shared memory (also, they
may be implemented differently by each backend), so we need to add
a new NIR variable mode for them.

These payload variables can't be represented by shader outputs
because the 16K bytes don't fit the 32x vec4 model that NIR uses
for its output variables.

This patch adds a new NIR variable mode: nir_var_mem_task_payload
and corresponding explicit I/O intrinsics, as well as support for
this new mode in nir_lower_io.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Caio Oliveira <caio.oliveira@intel.com>
Reviewed-by: Jason Ekstrand <jason.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14930>
This commit is contained in:
Timur Kristóf
2022-02-08 02:55:18 +01:00
committed by Marge Bot
parent d2d6eca081
commit f629fbd778
7 changed files with 51 additions and 10 deletions

View File

@@ -285,6 +285,7 @@ nir_shader_add_variable(nir_shader *shader, nir_variable *var)
case nir_var_mem_constant:
case nir_var_shader_call_data:
case nir_var_ray_hit_attrib:
case nir_var_mem_task_payload:
break;
case nir_var_mem_global:

View File

@@ -185,14 +185,15 @@ typedef enum {
nir_var_mem_push_const = (1 << 8),
nir_var_mem_ssbo = (1 << 9),
nir_var_mem_constant = (1 << 10),
nir_var_mem_task_payload = (1 << 11),
/* Generic modes intentionally come last. See encode_dref_modes() in
* nir_serialize.c for more details.
*/
nir_var_shader_temp = (1 << 11),
nir_var_function_temp = (1 << 12),
nir_var_mem_shared = (1 << 13),
nir_var_mem_global = (1 << 14),
nir_var_shader_temp = (1 << 12),
nir_var_function_temp = (1 << 13),
nir_var_mem_shared = (1 << 14),
nir_var_mem_global = (1 << 15),
nir_var_mem_generic = (nir_var_shader_temp |
nir_var_function_temp |
@@ -206,7 +207,7 @@ typedef enum {
nir_var_vec_indexable_modes = nir_var_mem_ubo | nir_var_mem_ssbo |
nir_var_mem_shared | nir_var_mem_global |
nir_var_mem_push_const,
nir_num_variable_modes = 15,
nir_num_variable_modes = 16,
nir_var_all = (1 << nir_num_variable_modes) - 1,
} nir_variable_mode;
MESA_DEFINE_CPP_ENUM_BITFIELD_OPERATORS(nir_variable_mode)

View File

@@ -970,6 +970,8 @@ load("per_primitive_output", [1, 1], [BASE, COMPONENT, DEST_TYPE, IO_SEMANTICS],
# src[] = { offset }.
load("shared", [1], [BASE, ALIGN_MUL, ALIGN_OFFSET], [CAN_ELIMINATE])
# src[] = { offset }.
load("task_payload", [1], [BASE, ALIGN_MUL, ALIGN_OFFSET], [CAN_ELIMINATE])
# src[] = { offset }.
load("push_constant", [1], [BASE, RANGE], [CAN_ELIMINATE, CAN_REORDER])
# src[] = { offset }.
load("constant", [1], [BASE, RANGE, ALIGN_MUL, ALIGN_OFFSET],
@@ -1008,6 +1010,8 @@ store("per_primitive_output", [1, 1], [BASE, WRITE_MASK, COMPONENT, SRC_TYPE, IO
store("ssbo", [-1, 1], [WRITE_MASK, ACCESS, ALIGN_MUL, ALIGN_OFFSET])
# src[] = { value, offset }.
store("shared", [1], [BASE, WRITE_MASK, ALIGN_MUL, ALIGN_OFFSET])
# src[] = { value, offset }.
store("task_payload", [1], [BASE, WRITE_MASK, ALIGN_MUL, ALIGN_OFFSET])
# src[] = { value, address }.
store("global", [1], [WRITE_MASK, ACCESS, ALIGN_MUL, ALIGN_OFFSET])
# src[] = { value, offset }.

View File

@@ -878,6 +878,7 @@ build_addr_for_var(nir_builder *b, nir_variable *var,
nir_address_format addr_format)
{
assert(var->data.mode & (nir_var_uniform | nir_var_mem_shared |
nir_var_mem_task_payload |
nir_var_shader_temp | nir_var_function_temp |
nir_var_mem_push_const | nir_var_mem_constant));
@@ -1332,6 +1333,10 @@ build_explicit_io_load(nir_builder *b, nir_intrinsic_instr *intrin,
assert(addr_format_is_offset(addr_format, mode));
op = nir_intrinsic_load_shared;
break;
case nir_var_mem_task_payload:
assert(addr_format_is_offset(addr_format, mode));
op = nir_intrinsic_load_task_payload;
break;
case nir_var_shader_temp:
case nir_var_function_temp:
if (addr_format_is_offset(addr_format, mode)) {
@@ -1554,6 +1559,10 @@ build_explicit_io_store(nir_builder *b, nir_intrinsic_instr *intrin,
assert(addr_format_is_offset(addr_format, mode));
op = nir_intrinsic_store_shared;
break;
case nir_var_mem_task_payload:
assert(addr_format_is_offset(addr_format, mode));
op = nir_intrinsic_store_task_payload;
break;
case nir_var_shader_temp:
case nir_var_function_temp:
if (addr_format_is_offset(addr_format, mode)) {
@@ -2308,6 +2317,9 @@ lower_vars_to_explicit(nir_shader *shader,
case nir_var_mem_shared:
offset = shader->info.shared_size;
break;
case nir_var_mem_task_payload:
offset = shader->info.task_payload_size;
break;
case nir_var_mem_constant:
offset = shader->constant_data_size;
break;
@@ -2351,6 +2363,9 @@ lower_vars_to_explicit(nir_shader *shader,
case nir_var_mem_shared:
shader->info.shared_size = offset;
break;
case nir_var_mem_task_payload:
shader->info.task_payload_size = offset;
break;
case nir_var_mem_constant:
shader->constant_data_size = offset;
break;
@@ -2381,7 +2396,8 @@ nir_lower_vars_to_explicit_types(nir_shader *shader,
ASSERTED nir_variable_mode supported =
nir_var_mem_shared | nir_var_mem_global | nir_var_mem_constant |
nir_var_shader_temp | nir_var_function_temp | nir_var_uniform |
nir_var_shader_call_data | nir_var_ray_hit_attrib;
nir_var_shader_call_data | nir_var_ray_hit_attrib |
nir_var_mem_task_payload;
assert(!(modes & ~supported) && "unsupported");
bool progress = false;
@@ -2402,6 +2418,8 @@ nir_lower_vars_to_explicit_types(nir_shader *shader,
progress |= lower_vars_to_explicit(shader, &shader->variables, nir_var_shader_call_data, type_info);
if (modes & nir_var_ray_hit_attrib)
progress |= lower_vars_to_explicit(shader, &shader->variables, nir_var_ray_hit_attrib, type_info);
if (modes & nir_var_mem_task_payload)
progress |= lower_vars_to_explicit(shader, &shader->variables, nir_var_mem_task_payload, type_info);
nir_foreach_function(function, shader) {
if (function->impl) {
@@ -2497,6 +2515,7 @@ nir_get_io_offset_src(nir_intrinsic_instr *instr)
case nir_intrinsic_load_input:
case nir_intrinsic_load_output:
case nir_intrinsic_load_shared:
case nir_intrinsic_load_task_payload:
case nir_intrinsic_load_uniform:
case nir_intrinsic_load_kernel_input:
case nir_intrinsic_load_global:
@@ -2541,6 +2560,7 @@ nir_get_io_offset_src(nir_intrinsic_instr *instr)
case nir_intrinsic_load_interpolated_input:
case nir_intrinsic_store_output:
case nir_intrinsic_store_shared:
case nir_intrinsic_store_task_payload:
case nir_intrinsic_store_global:
case nir_intrinsic_store_scratch:
case nir_intrinsic_ssbo_atomic_add:

View File

@@ -558,6 +558,8 @@ get_variable_mode_str(nir_variable_mode mode, bool want_local_global_mode)
return "shader_call_data";
case nir_var_ray_hit_attrib:
return "ray_hit_attrib";
case nir_var_mem_task_payload:
return "task_payload";
default:
return "";
}
@@ -1665,6 +1667,10 @@ nir_print_shader_annotated(nir_shader *shader, FILE *fp,
shader->info.workgroup_size_variable ? " (variable)" : "");
fprintf(fp, "shared-size: %u\n", shader->info.shared_size);
}
if (shader->info.stage == MESA_SHADER_MESH ||
shader->info.stage == MESA_SHADER_TASK) {
fprintf(fp, "task_payload-size: %u\n", shader->info.task_payload_size);
}
fprintf(fp, "inputs: %u\n", shader->num_inputs);
fprintf(fp, "outputs: %u\n", shader->num_outputs);

View File

@@ -1749,6 +1749,10 @@ nir_validate_shader(nir_shader *shader, const char *when)
shader->info.stage == MESA_SHADER_INTERSECTION)
valid_modes |= nir_var_ray_hit_attrib;
if (shader->info.stage == MESA_SHADER_TASK ||
shader->info.stage == MESA_SHADER_MESH)
valid_modes |= nir_var_mem_task_payload;
exec_list_validate(&shader->variables);
nir_foreach_variable_in_shader(var, shader)
validate_var_decl(var, valid_modes, &state);

View File

@@ -215,6 +215,11 @@ typedef struct shader_info {
*/
unsigned shared_size;
/**
* Size of task payload variables accessed by task/mesh shaders.
*/
unsigned task_payload_size;
/**
* Number of ray tracing queries in the shader (counts all elements of all
* variables).