diff --git a/src/microsoft/compiler/dxil_nir.c b/src/microsoft/compiler/dxil_nir.c
index fe25f23552b..97ad6f1cf67 100644
--- a/src/microsoft/compiler/dxil_nir.c
+++ b/src/microsoft/compiler/dxil_nir.c
@@ -1230,3 +1230,80 @@ dxil_nir_split_clip_cull_distance(nir_shader *shader)
 
    return new_var != NULL;
 }
+
+/* Rewrites 64-bit float ALU operands and results into the DXIL-specific
+ * packed form: each double operand is split with unpack_64_2x32 and
+ * rebuilt with pack_double_2x32_dxil, and each double result is split
+ * with unpack_double_2x32_dxil and rebuilt with pack_64_2x32.  Returns
+ * true if any instruction was rewritten.
+ */
+bool
+dxil_nir_lower_double_math(nir_shader *shader)
+{
+   bool progress = false;
+   nir_foreach_function(func, shader) {
+      bool func_progress = false;
+      if (!func->impl)
+         continue;
+
+      nir_builder b;
+      nir_builder_init(&b, func->impl);
+      nir_foreach_block(block, func->impl) {
+         nir_foreach_instr_safe(instr, block) {
+            if (instr->type != nir_instr_type_alu)
+               continue;
+
+            nir_alu_instr *alu = nir_instr_as_alu(instr);
+
+            /* TODO: See if we can apply this explicitly to packs/unpacks that are then
+             * used as a double. As-is, if we had an app explicitly do a 64bit integer op,
+             * then try to bitcast to double (not expressible in HLSL, but it is in other
+             * source languages), this would unpack the integer and repack as a double, when
+             * we probably want to just send the bitcast through to the backend.
+             */
+
+            b.cursor = nir_before_instr(&alu->instr);
+
+            for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i) {
+               if (nir_alu_type_get_base_type(nir_op_infos[alu->op].input_types[i]) == nir_type_float &&
+                   alu->src[i].src.ssa->bit_size == 64) {
+                  nir_ssa_def *packed_double = nir_channel(&b, alu->src[i].src.ssa, alu->src[i].swizzle[0]);
+                  nir_ssa_def *unpacked_double = nir_unpack_64_2x32(&b, packed_double);
+                  nir_ssa_def *repacked_double = nir_pack_double_2x32_dxil(&b, unpacked_double);
+                  nir_instr_rewrite_src_ssa(instr, &alu->src[i].src, repacked_double);
+                  /* The replacement def is a scalar, so the swizzle must now
+                   * select component 0.  Use sizeof, not ARRAY_SIZE: memset
+                   * takes a byte count (the two only coincide because swizzle
+                   * elements are one byte wide).
+                   */
+                  memset(alu->src[i].swizzle, 0, sizeof(alu->src[i].swizzle));
+                  func_progress = true;
+               }
+            }
+
+            if (nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type) == nir_type_float &&
+                alu->dest.dest.ssa.bit_size == 64) {
+               b.cursor = nir_after_instr(&alu->instr);
+               nir_ssa_def *packed_double = &alu->dest.dest.ssa;
+               nir_ssa_def *unpacked_double = nir_unpack_double_2x32_dxil(&b, packed_double);
+               nir_ssa_def *repacked_double = nir_pack_64_2x32(&b, unpacked_double);
+               /* Rewrite uses only after the unpack we just emitted, so the
+                * unpack itself keeps reading the original ALU result.
+                */
+               nir_ssa_def_rewrite_uses_after(packed_double, repacked_double, unpacked_double->parent_instr);
+               func_progress = true;
+            }
+         }
+      }
+
+      if (func_progress)
+         nir_metadata_preserve(func->impl, nir_metadata_block_index |
+                                           nir_metadata_dominance |
+                                           nir_metadata_loop_analysis);
+      else
+         nir_metadata_preserve(func->impl, nir_metadata_all);
+      progress |= func_progress;
+   }
+
+   return progress;
+}
diff --git a/src/microsoft/compiler/dxil_nir.h b/src/microsoft/compiler/dxil_nir.h
index 6c11d117d78..fe657efdd0c 100644
--- a/src/microsoft/compiler/dxil_nir.h
+++ b/src/microsoft/compiler/dxil_nir.h
@@ -45,6 +45,7 @@ bool dxil_nir_lower_memcpy_deref(nir_shader *shader);
 bool dxil_nir_lower_upcast_phis(nir_shader *shader, unsigned min_bit_size);
 bool dxil_nir_lower_fp16_casts(nir_shader *shader);
 bool dxil_nir_split_clip_cull_distance(nir_shader *shader);
+bool dxil_nir_lower_double_math(nir_shader *shader);
 
 nir_ssa_def *
 build_load_ubo_dxil(nir_builder *b, nir_ssa_def *buffer,