radv/nir: Add radv_nir_lower_hit_attrib_derefs

Move out the pass so it can be unit tested.

Reviewed-by: Friedrich Vock <friedrich.vock@gmx.de>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24271>
This commit is contained in:
Konstantin Seurer
2023-07-20 21:13:22 +02:00
committed by Marge Bot
parent b7c582e5c7
commit 3a69424e09
4 changed files with 123 additions and 106 deletions

View File

@@ -79,6 +79,7 @@ libradv_files = files(
'nir/radv_nir_lower_cooperative_matrix.c',
'nir/radv_nir_lower_fs_barycentric.c',
'nir/radv_nir_lower_fs_intrinsics.c',
'nir/radv_nir_lower_hit_attrib_derefs.c',
'nir/radv_nir_lower_intrinsics_early.c',
'nir/radv_nir_lower_io.c',
'nir/radv_nir_lower_poly_line_smooth.c',

View File

@@ -50,6 +50,8 @@ void radv_nir_lower_abi(nir_shader *shader, enum amd_gfx_level gfx_level, const
const struct radv_shader_args *args, const struct radv_pipeline_key *pl_key,
uint32_t address32_hi);
bool radv_nir_lower_hit_attrib_derefs(nir_shader *shader);
bool radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device);
bool radv_nir_lower_vs_inputs(nir_shader *shader, const struct radv_shader_stage *vs_stage,

View File

@@ -0,0 +1,116 @@
/*
* Copyright © 2021 Google
* Copyright © 2023 Valve Corporation
* SPDX-License-Identifier: MIT
*/
#include "nir.h"
#include "nir_builder.h"
#include "radv_nir.h"
static bool
lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
{
if (instr->type != nir_instr_type_intrinsic)
return false;
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
if (intrin->intrinsic != nir_intrinsic_load_deref && intrin->intrinsic != nir_intrinsic_store_deref)
return false;
nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
if (!nir_deref_mode_is(deref, nir_var_ray_hit_attrib))
return false;
assert(deref->deref_type == nir_deref_type_var);
b->cursor = nir_after_instr(instr);
if (intrin->intrinsic == nir_intrinsic_load_deref) {
uint32_t num_components = intrin->def.num_components;
uint32_t bit_size = intrin->def.bit_size;
nir_def *components[NIR_MAX_VEC_COMPONENTS];
for (uint32_t comp = 0; comp < num_components; comp++) {
uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8;
uint32_t base = offset / 4;
uint32_t comp_offset = offset % 4;
if (bit_size == 64) {
components[comp] = nir_pack_64_2x32_split(b, nir_load_hit_attrib_amd(b, .base = base),
nir_load_hit_attrib_amd(b, .base = base + 1));
} else if (bit_size == 32) {
components[comp] = nir_load_hit_attrib_amd(b, .base = base);
} else if (bit_size == 16) {
components[comp] =
nir_channel(b, nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)), comp_offset / 2);
} else if (bit_size == 8) {
components[comp] =
nir_channel(b, nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8), comp_offset);
} else {
unreachable("Invalid bit_size");
}
}
nir_def_rewrite_uses(&intrin->def, nir_vec(b, components, num_components));
} else {
nir_def *value = intrin->src[1].ssa;
uint32_t num_components = value->num_components;
uint32_t bit_size = value->bit_size;
for (uint32_t comp = 0; comp < num_components; comp++) {
uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8;
uint32_t base = offset / 4;
uint32_t comp_offset = offset % 4;
nir_def *component = nir_channel(b, value, comp);
if (bit_size == 64) {
nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_x(b, component), .base = base);
nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_y(b, component), .base = base + 1);
} else if (bit_size == 32) {
nir_store_hit_attrib_amd(b, component, .base = base);
} else if (bit_size == 16) {
nir_def *prev = nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base));
nir_def *components[2];
for (uint32_t word = 0; word < 2; word++)
components[word] = (word == comp_offset / 2) ? nir_channel(b, value, comp) : nir_channel(b, prev, word);
nir_store_hit_attrib_amd(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)), .base = base);
} else if (bit_size == 8) {
nir_def *prev = nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8);
nir_def *components[4];
for (uint32_t byte = 0; byte < 4; byte++)
components[byte] = (byte == comp_offset) ? nir_channel(b, value, comp) : nir_channel(b, prev, byte);
nir_store_hit_attrib_amd(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)), .base = base);
} else {
unreachable("Invalid bit_size");
}
}
}
nir_instr_remove(instr);
return true;
}
bool
radv_nir_lower_hit_attrib_derefs(nir_shader *shader)
{
bool progress = false;
progress |= nir_split_struct_vars(shader, nir_var_ray_hit_attrib);
progress |= nir_lower_indirect_derefs(shader, nir_var_ray_hit_attrib, UINT32_MAX);
progress |= nir_split_array_vars(shader, nir_var_ray_hit_attrib);
progress |= nir_lower_vars_to_explicit_types(shader, nir_var_ray_hit_attrib, glsl_get_natural_size_align_bytes);
progress |= nir_shader_instructions_pass(shader, lower_hit_attrib_deref,
nir_metadata_block_index | nir_metadata_dominance, NULL);
if (progress) {
nir_remove_dead_derefs(shader);
nir_remove_dead_variables(shader, nir_var_ray_hit_attrib, NULL);
}
return progress;
}

View File

@@ -26,6 +26,7 @@
#include "bvh/bvh.h"
#include "meta/radv_meta.h"
#include "nir/radv_nir.h"
#include "ac_nir.h"
#include "radv_private.h"
#include "radv_rt_common.h"
@@ -578,104 +579,6 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, bool apply_
nir_shader_instructions_pass(shader, radv_lower_rt_instruction, nir_metadata_none, &data);
}
static bool
lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
{
if (instr->type != nir_instr_type_intrinsic)
return false;
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
if (intrin->intrinsic != nir_intrinsic_load_deref && intrin->intrinsic != nir_intrinsic_store_deref)
return false;
nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
if (!nir_deref_mode_is(deref, nir_var_ray_hit_attrib))
return false;
assert(deref->deref_type == nir_deref_type_var);
b->cursor = nir_after_instr(instr);
if (intrin->intrinsic == nir_intrinsic_load_deref) {
uint32_t num_components = intrin->def.num_components;
uint32_t bit_size = intrin->def.bit_size;
nir_def *components[NIR_MAX_VEC_COMPONENTS];
for (uint32_t comp = 0; comp < num_components; comp++) {
uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8;
uint32_t base = offset / 4;
uint32_t comp_offset = offset % 4;
if (bit_size == 64) {
components[comp] = nir_pack_64_2x32_split(b, nir_load_hit_attrib_amd(b, .base = base),
nir_load_hit_attrib_amd(b, .base = base + 1));
} else if (bit_size == 32) {
components[comp] = nir_load_hit_attrib_amd(b, .base = base);
} else if (bit_size == 16) {
components[comp] =
nir_channel(b, nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)), comp_offset / 2);
} else if (bit_size == 8) {
components[comp] =
nir_channel(b, nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8), comp_offset);
} else {
unreachable("Invalid bit_size");
}
}
nir_def_rewrite_uses(&intrin->def, nir_vec(b, components, num_components));
} else {
nir_def *value = intrin->src[1].ssa;
uint32_t num_components = value->num_components;
uint32_t bit_size = value->bit_size;
for (uint32_t comp = 0; comp < num_components; comp++) {
uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8;
uint32_t base = offset / 4;
uint32_t comp_offset = offset % 4;
nir_def *component = nir_channel(b, value, comp);
if (bit_size == 64) {
nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_x(b, component), .base = base);
nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_y(b, component), .base = base + 1);
} else if (bit_size == 32) {
nir_store_hit_attrib_amd(b, component, .base = base);
} else if (bit_size == 16) {
nir_def *prev = nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base));
nir_def *components[2];
for (uint32_t word = 0; word < 2; word++)
components[word] = (word == comp_offset / 2) ? nir_channel(b, value, comp) : nir_channel(b, prev, word);
nir_store_hit_attrib_amd(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)), .base = base);
} else if (bit_size == 8) {
nir_def *prev = nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8);
nir_def *components[4];
for (uint32_t byte = 0; byte < 4; byte++)
components[byte] = (byte == comp_offset) ? nir_channel(b, value, comp) : nir_channel(b, prev, byte);
nir_store_hit_attrib_amd(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)), .base = base);
} else {
unreachable("Invalid bit_size");
}
}
}
nir_instr_remove(instr);
return true;
}
static bool
lower_hit_attrib_derefs(nir_shader *shader)
{
bool progress = nir_shader_instructions_pass(shader, lower_hit_attrib_deref,
nir_metadata_block_index | nir_metadata_dominance, NULL);
if (progress) {
nir_remove_dead_derefs(shader);
nir_remove_dead_variables(shader, nir_var_ray_hit_attrib, NULL);
}
return progress;
}
/* Lowers hit attributes to registers or shared memory. If hit_attribs is NULL, attributes are
* lowered to shared memory. */
static void
@@ -802,16 +705,11 @@ radv_parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreat
nir_shader *shader = radv_shader_spirv_to_nir(device, &rt_stage, key, false);
NIR_PASS(_, shader, nir_split_struct_vars, nir_var_ray_hit_attrib);
NIR_PASS(_, shader, nir_lower_indirect_derefs, nir_var_ray_hit_attrib, UINT32_MAX);
NIR_PASS(_, shader, nir_split_array_vars, nir_var_ray_hit_attrib);
NIR_PASS(_, shader, nir_lower_vars_to_explicit_types,
nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib,
NIR_PASS(_, shader, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data,
glsl_get_natural_size_align_bytes);
NIR_PASS(_, shader, lower_rt_derefs);
NIR_PASS(_, shader, lower_hit_attrib_derefs);
NIR_PASS(_, shader, radv_nir_lower_hit_attrib_derefs);
NIR_PASS(_, shader, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset);
@@ -1485,7 +1383,7 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin
radv_build_ray_traversal(device, b, &args);
nir_metadata_preserve(nir_shader_get_entrypoint(b->shader), nir_metadata_none);
lower_hit_attrib_derefs(b->shader);
radv_nir_lower_hit_attrib_derefs(b->shader);
/* Register storage for hit attributes */
nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_DWORDS];