diff --git a/src/amd/compiler/aco_register_allocation.cpp b/src/amd/compiler/aco_register_allocation.cpp index a037745aaaf..ae96d02bdcb 100644 --- a/src/amd/compiler/aco_register_allocation.cpp +++ b/src/amd/compiler/aco_register_allocation.cpp @@ -495,6 +495,7 @@ get_subdword_operand_stride(chip_class chip, const aco_ptr& instr, case aco_opcode::ds_write_b16: return chip >= GFX9 ? 2 : 4; case aco_opcode::buffer_store_byte: case aco_opcode::buffer_store_short: + case aco_opcode::buffer_store_format_d16_x: case aco_opcode::flat_store_byte: case aco_opcode::flat_store_short: case aco_opcode::scratch_store_byte: @@ -552,6 +553,8 @@ add_subdword_operand(ra_ctx& ctx, aco_ptr& instr, unsigned idx, uns instr->opcode = aco_opcode::buffer_store_byte_d16_hi; else if (instr->opcode == aco_opcode::buffer_store_short) instr->opcode = aco_opcode::buffer_store_short_d16_hi; + else if (instr->opcode == aco_opcode::buffer_store_format_d16_x) + instr->opcode = aco_opcode::buffer_store_format_d16_hi_x; else if (instr->opcode == aco_opcode::flat_store_byte) instr->opcode = aco_opcode::flat_store_byte_d16_hi; else if (instr->opcode == aco_opcode::flat_store_short) @@ -601,6 +604,7 @@ get_subdword_definition_info(Program* program, const aco_ptr& instr } switch (instr->opcode) { + /* D16 loads with _hi version */ case aco_opcode::ds_read_u8_d16: case aco_opcode::ds_read_i8_d16: case aco_opcode::ds_read_u16_d16: @@ -615,16 +619,32 @@ get_subdword_definition_info(Program* program, const aco_ptr& instr case aco_opcode::scratch_load_short_d16: case aco_opcode::buffer_load_ubyte_d16: case aco_opcode::buffer_load_sbyte_d16: - case aco_opcode::buffer_load_short_d16: { + case aco_opcode::buffer_load_short_d16: + case aco_opcode::buffer_load_format_d16_x: { assert(chip >= GFX9); if (!program->dev.sram_ecc_enabled) return std::make_pair(2u, 2u); else return std::make_pair(2u, 4u); } - - default: return std::make_pair(4, rc.size() * 4u); + /* 3-component D16 loads */ + case aco_opcode::buffer_load_format_d16_xyz: + case aco_opcode::tbuffer_load_format_d16_xyz: { + assert(chip >= GFX9); + if (!program->dev.sram_ecc_enabled) + return std::make_pair(4u, 6u); + break; } + + default: break; + } + + if (instr->isMIMG() && instr->mimg().d16 && !program->dev.sram_ecc_enabled) { + assert(chip >= GFX9); + return std::make_pair(4u, rc.bytes()); + } + + return std::make_pair(4, rc.size() * 4u); } void @@ -667,6 +687,8 @@ add_subdword_definition(Program* program, aco_ptr& instr, PhysReg r instr->opcode = aco_opcode::buffer_load_sbyte_d16_hi; else if (instr->opcode == aco_opcode::buffer_load_short_d16) instr->opcode = aco_opcode::buffer_load_short_d16_hi; + else if (instr->opcode == aco_opcode::buffer_load_format_d16_x) + instr->opcode = aco_opcode::buffer_load_format_d16_hi_x; else if (instr->opcode == aco_opcode::flat_load_ubyte_d16) instr->opcode = aco_opcode::flat_load_ubyte_d16_hi; else if (instr->opcode == aco_opcode::flat_load_sbyte_d16) diff --git a/src/amd/compiler/aco_validate.cpp b/src/amd/compiler/aco_validate.cpp index 8c5a84ac3f9..445b4cd4918 100644 --- a/src/amd/compiler/aco_validate.cpp +++ b/src/amd/compiler/aco_validate.cpp @@ -256,8 +256,8 @@ validate_ir(Program* program) /* check subdword definitions */ for (unsigned i = 0; i < instr->definitions.size(); i++) { if (instr->definitions[i].regClass().is_subdword()) - check(instr->isPseudo() || instr->definitions[i].bytes() <= 4, - "Only Pseudo instructions can write subdword registers larger than 4 bytes", + check(instr->definitions[i].bytes() <= 4 || instr->isPseudo() || instr->isVMEM(), + "Only Pseudo and VMEM instructions can write subdword registers > 4 bytes", instr.get()); } @@ -542,6 +542,36 @@ validate_ir(Program* program) (instr->operands[3].isTemp() && instr->operands[3].regClass().type() == RegType::vgpr), "VMEM write data must be vgpr", instr.get()); + + const bool d16 = instr->opcode == aco_opcode::buffer_load_dword || // FIXME: used to spill subdword variables + instr->opcode == aco_opcode::buffer_load_ubyte || + instr->opcode == aco_opcode::buffer_load_sbyte || + instr->opcode == aco_opcode::buffer_load_ushort || + instr->opcode == aco_opcode::buffer_load_sshort || + instr->opcode == aco_opcode::buffer_load_ubyte_d16 || + instr->opcode == aco_opcode::buffer_load_ubyte_d16_hi || + instr->opcode == aco_opcode::buffer_load_sbyte_d16 || + instr->opcode == aco_opcode::buffer_load_sbyte_d16_hi || + instr->opcode == aco_opcode::buffer_load_short_d16 || + instr->opcode == aco_opcode::buffer_load_short_d16_hi || + instr->opcode == aco_opcode::buffer_load_format_d16_x || + instr->opcode == aco_opcode::buffer_load_format_d16_hi_x || + instr->opcode == aco_opcode::buffer_load_format_d16_xy || + instr->opcode == aco_opcode::buffer_load_format_d16_xyz || + instr->opcode == aco_opcode::buffer_load_format_d16_xyzw || + instr->opcode == aco_opcode::tbuffer_load_format_d16_x || + instr->opcode == aco_opcode::tbuffer_load_format_d16_xy || + instr->opcode == aco_opcode::tbuffer_load_format_d16_xyz || + instr->opcode == aco_opcode::tbuffer_load_format_d16_xyzw; + if (instr->definitions.size()) { + check(instr->definitions[0].isTemp() && + instr->definitions[0].regClass().type() == RegType::vgpr, + "VMEM definitions[0] (VDATA) must be VGPR", instr.get()); + check(d16 || !instr->definitions[0].regClass().is_subdword(), + "Only D16 opcodes can load subdword values.", instr.get()); + check(instr->definitions[0].bytes() <= 8 || !d16, + "D16 opcodes can only load up to 8 bytes.", instr.get()); + } break; } case Format::MIMG: { @@ -575,10 +605,16 @@ validate_ir(Program* program) instr.get()); } } - check(instr->definitions.empty() || - (instr->definitions[0].isTemp() && - instr->definitions[0].regClass().type() == RegType::vgpr), - "MIMG definitions[0] (VDATA) must be VGPR", instr.get()); + + if (instr->definitions.size()) { + check(instr->definitions[0].isTemp() && + instr->definitions[0].regClass().type() == RegType::vgpr, + "MIMG definitions[0] (VDATA) must be VGPR", instr.get()); + check(instr->mimg().d16 || !instr->definitions[0].regClass().is_subdword(), + "Only D16 MIMG instructions can load subdword values.", instr.get()); + check(instr->definitions[0].bytes() <= 8 || !instr->mimg().d16, + "D16 MIMG instructions can only load up to 8 bytes.", instr.get()); + } break; } case Format::DS: { @@ -744,6 +780,7 @@ validate_subdword_operand(chip_class chip, const aco_ptr& instr, un break; case aco_opcode::buffer_store_byte_d16_hi: case aco_opcode::buffer_store_short_d16_hi: + case aco_opcode::buffer_store_format_d16_hi_x: if (byte == 2 && index == 3) return true; break; @@ -778,7 +815,9 @@ validate_subdword_definition(chip_class chip, const aco_ptr& instr) switch (instr->opcode) { case aco_opcode::buffer_load_ubyte_d16_hi: + case aco_opcode::buffer_load_sbyte_d16_hi: case aco_opcode::buffer_load_short_d16_hi: + case aco_opcode::buffer_load_format_d16_hi_x: case aco_opcode::flat_load_ubyte_d16_hi: case aco_opcode::flat_load_short_d16_hi: case aco_opcode::scratch_load_ubyte_d16_hi: @@ -812,9 +851,17 @@ get_subdword_bytes_written(Program* program, const aco_ptr& instr, return 4; } + if (instr->isMIMG()) { + assert(instr->mimg().d16); + return program->dev.sram_ecc_enabled ? def.size() * 4u : def.bytes(); + } + switch (instr->opcode) { case aco_opcode::buffer_load_ubyte_d16: + case aco_opcode::buffer_load_sbyte_d16: case aco_opcode::buffer_load_short_d16: + case aco_opcode::buffer_load_format_d16_x: + case aco_opcode::tbuffer_load_format_d16_x: case aco_opcode::flat_load_ubyte_d16: case aco_opcode::flat_load_short_d16: case aco_opcode::scratch_load_ubyte_d16: @@ -824,7 +871,9 @@ get_subdword_bytes_written(Program* program, const aco_ptr& instr, case aco_opcode::ds_read_u8_d16: case aco_opcode::ds_read_u16_d16: case aco_opcode::buffer_load_ubyte_d16_hi: + case aco_opcode::buffer_load_sbyte_d16_hi: case aco_opcode::buffer_load_short_d16_hi: + case aco_opcode::buffer_load_format_d16_hi_x: case aco_opcode::flat_load_ubyte_d16_hi: case aco_opcode::flat_load_short_d16_hi: case aco_opcode::scratch_load_ubyte_d16_hi: @@ -833,6 +882,8 @@ get_subdword_bytes_written(Program* program, const aco_ptr& instr, case aco_opcode::global_load_short_d16_hi: case aco_opcode::ds_read_u8_d16_hi: case aco_opcode::ds_read_u16_d16_hi: return program->dev.sram_ecc_enabled ? 4 : 2; + case aco_opcode::buffer_load_format_d16_xyz: + case aco_opcode::tbuffer_load_format_d16_xyz: return program->dev.sram_ecc_enabled ? 8 : 6; default: return def.size() * 4; } }