diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c index c917b11c643..f480a57f3f3 100644 --- a/src/amd/vulkan/radv_pipeline.c +++ b/src/amd/vulkan/radv_pipeline.c @@ -2948,11 +2948,15 @@ mem_vectorize_callback(unsigned align_mul, unsigned align_offset, } static unsigned -lower_bit_size_callback(const nir_alu_instr *alu, void *_) +lower_bit_size_callback(const nir_instr *instr, void *_) { struct radv_device *device = _; enum chip_class chip = device->physical_device->rad_info.chip_class; + if (instr->type != nir_instr_type_alu) + return 0; + nir_alu_instr *alu = nir_instr_as_alu(instr); + if (alu->dest.dest.ssa.bit_size & (8 | 16)) { unsigned bit_size = alu->dest.dest.ssa.bit_size; switch (alu->op) { diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 389c58587b3..9365c163773 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -4948,7 +4948,7 @@ typedef enum { bool nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags options); -typedef unsigned (*nir_lower_bit_size_callback)(const nir_alu_instr *, void *); +typedef unsigned (*nir_lower_bit_size_callback)(const nir_instr *, void *); bool nir_lower_bit_size(nir_shader *shader, nir_lower_bit_size_callback callback, diff --git a/src/compiler/nir/nir_lower_bit_size.c b/src/compiler/nir/nir_lower_bit_size.c index 0508bdd3d87..a53090a8760 100644 --- a/src/compiler/nir/nir_lower_bit_size.c +++ b/src/compiler/nir/nir_lower_bit_size.c @@ -46,7 +46,7 @@ static nir_ssa_def *convert_to_bit_size(nir_builder *bld, nir_ssa_def *src, } static void -lower_instr(nir_builder *bld, nir_alu_instr *alu, unsigned bit_size) +lower_alu_instr(nir_builder *bld, nir_alu_instr *alu, unsigned bit_size) { const nir_op op = alu->op; unsigned dst_bit_size = alu->dest.dest.ssa.bit_size; @@ -109,14 +109,11 @@ lower_impl(nir_function_impl *impl, if (instr->type != nir_instr_type_alu) continue; - nir_alu_instr *alu = nir_instr_as_alu(instr); - assert(alu->dest.dest.is_ssa); - - unsigned lower_bit_size = callback(alu, callback_data); + unsigned lower_bit_size = callback(instr, callback_data); if (lower_bit_size == 0) continue; - lower_instr(&b, alu, lower_bit_size); + lower_alu_instr(&b, nir_instr_as_alu(instr), lower_bit_size); progress = true; } } diff --git a/src/intel/compiler/brw_nir.c b/src/intel/compiler/brw_nir.c index 245003c8dbb..61c1ef98338 100644 --- a/src/intel/compiler/brw_nir.c +++ b/src/intel/compiler/brw_nir.c @@ -633,8 +633,12 @@ brw_nir_optimize(nir_shader *nir, const struct brw_compiler *compiler, } static unsigned -lower_bit_size_callback(const nir_alu_instr *alu, UNUSED void *data) +lower_bit_size_callback(const nir_instr *instr, UNUSED void *data) { + if (instr->type != nir_instr_type_alu) + return 0; + + nir_alu_instr *alu = nir_instr_as_alu(instr); assert(alu->dest.dest.is_ssa); if (alu->dest.dest.ssa.bit_size >= 32) return 0;