Skip to content

Commit f6f28a7

Browse files
mmereckiigcbot
authored andcommitted
Another pattern match for packing <4 x i8> values.
This PR add detection for one more pattern that packs 4 8-bit integer values into a single 32-bit value.
1 parent 67faf6c commit f6f28a7

File tree

3 files changed

+201
-9
lines changed

3 files changed

+201
-9
lines changed

IGC/Compiler/CISACodeGen/EmitVISAPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2426,6 +2426,7 @@ void EmitPass::EmitPack4i8(const std::array<EOPCODE, 4> &opcodes, const std::arr
24262426
CVariable *src0 = GetSrcVariable(sources0[i]);
24272427
switch (opcodes[i]) {
24282428
case llvm_bitcast:
2429+
case llvm_fptosi:
24292430
m_encoder->Cast(dst, src0);
24302431
break;
24312432
case llvm_min:

IGC/Compiler/CISACodeGen/GenIRLowering.cpp

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ SPDX-License-Identifier: MIT
2121
#include <llvm/Analysis/ScalarEvolution.h>
2222
#include <llvm/Analysis/ScalarEvolutionExpressions.h>
2323
#include <llvm/Analysis/TargetFolder.h>
24+
#include <llvm/Analysis/ValueTracking.h>
2425
#include <llvm/IR/GetElementPtrTypeIterator.h>
26+
#include <llvm/Support/KnownBits.h>
2527
#include <llvm/Transforms/Utils/ScalarEvolutionExpander.h>
2628
#include <llvm/Transforms/Utils/Local.h>
2729
#include "llvmWrapper/IR/Intrinsics.h"
@@ -63,6 +65,7 @@ class GenIRLowering : public FunctionPass {
6365

6466
bool combineFMaxFMin(CallInst *GII, BasicBlock::iterator &BBI) const;
6567
bool combineSelectInst(SelectInst *Sel, BasicBlock::iterator &BBI) const;
68+
bool combinePack4i8Or2i16(Instruction *inst, uint64_t numBits) const;
6669

6770
bool constantFoldFMaxFMin(CallInst *GII, BasicBlock::iterator &BBI) const;
6871
};
@@ -362,6 +365,15 @@ bool GenIRLowering::runOnFunction(Function &F) {
362365
Changed |= combineSelectInst(cast<SelectInst>(Inst), BI);
363366
}
364367
break;
368+
case Instruction::Or:
369+
if (Inst->getType()->isIntegerTy(32)) {
370+
// Detect packing of 4 i8 values and convert to a pattern that is
371+
// matched CodeGenPatternMatch::MatchPack4i8
372+
Changed |= combinePack4i8Or2i16(Inst, 8 /*numBits*/);
373+
// TODO: also detect <2 x i16> packing once PatternMatch is updated
374+
// to packing of 16-bit values.
375+
}
376+
break;
365377
}
366378
}
367379
}
@@ -1000,6 +1012,173 @@ bool GenIRLowering::combineSelectInst(SelectInst *Sel, BasicBlock::iterator &BBI
10001012
return false;
10011013
}
10021014

1015+
////////////////////////////////////////////////////////////////////////////////
1016+
// Detect complex patterns that pack 2 16-bit or 4 8-bit integers into a 32-bit
1017+
// value. Generate equivalent sequence of instructions that is later matched in
1018+
// the CodeGenPatternMatch::MatchPack4i8().
1019+
// Pattern example for <4 x i8> packing:
1020+
// %x1 = and i32 %x, 127
1021+
// %x2 = lshr i32 %x, 24
1022+
// %x3 = and i32 %x2, 128
1023+
// %x4 = or i32 %x3, %x1
1024+
// %y1 = and i32 %y, 127
1025+
// %y2 = lshr i32 %y, 24
1026+
// %y3 = and i32 %y2, 128
1027+
// %y4 = or i32 %y3, %y1
1028+
// %y5 = shl nuw nsw i32 %y4, 8
1029+
// %xy = or i32 %x4, %y5
1030+
// %z1 = and i32 %z, 127
1031+
// %z2 = lshr i32 %z, 24
1032+
// %z3 = and i32 %z2, 128
1033+
// %z4 = or i32 %z3, %z1
1034+
// %z5 = shl nuw nsw i32 %z4, 16
1035+
// %xyz = or i32 %xy, %z5
1036+
// %w1 = shl nsw i32 %w, 24
1037+
// %w2 = and i32 %w1, 2130706432
1038+
// %w3 = and i32 %w, -2147483648
1039+
// %w4 = or i32 %w2, %w3
1040+
// %xyzw = or i32 %xyz, %w4
1041+
// and generate:
1042+
// %0 = trunc i32 %x to i8
1043+
// %1 = insertelement <4 x i8> poison, i8 %0, i32 0
1044+
// %2 = trunc i32 %y to i8
1045+
// %3 = insertelement <4 x i8> %1, i8 %2, i32 1
1046+
// %4 = trunc i32 %z to i8
1047+
// %5 = insertelement <4 x i8> %3, i8 %4, i32 2
1048+
// %6 = trunc i32 %w to i8
1049+
// %7 = insertelement <4 x i8> %5, i8 %6, i32 3
1050+
// %8 = bitcast <4 x i8> %7 to i32
1051+
bool GenIRLowering::combinePack4i8Or2i16(Instruction *inst, uint64_t numBits) const {
1052+
using namespace llvm::PatternMatch;
1053+
1054+
const DataLayout &DL = inst->getModule()->getDataLayout();
1055+
// Vector of 4 or 2 values that will be packed into a single 32-bit value.
1056+
// The std::pair contains the 32-bit value that contains the element
1057+
// to pack and the LSB where the packed value starts in the 32-bit value.
1058+
SmallVector<std::pair<Value *, uint64_t>, 4> toPack;
1059+
IGC_ASSERT(numBits == 8 || numBits == 16);
1060+
uint64_t packedVecSize = 32 / numBits;
1061+
toPack.resize(packedVecSize);
1062+
uint64_t cSignMask = QWBIT(numBits - 1);
1063+
uint64_t cMagnMask = BITMASK(numBits - 1);
1064+
// The std::pair contains the 32-bit value that contains the element
1065+
// to pack and the left shift bits that indicate the element position
1066+
// in the packed vector.
1067+
SmallVector<std::pair<Value *, uint64_t>, 4> args;
1068+
args.push_back({isa<BitCastInst>(inst) ? inst->getOperand(0) : inst, 0});
1069+
// In the first step traverse the chain of `or` and `shl` instructions
1070+
// and find all elements of the packed vector.
1071+
while (!args.empty()) {
1072+
auto [v, prevShlBits] = args.pop_back_val();
1073+
Value *lOp = nullptr;
1074+
Value *rOp = nullptr;
1075+
1076+
// Detect left shift by multiple of `numBits`. The `shl` operation sets the
1077+
// `index` argument in the corresponding InsertElement instruction in the
1078+
// final packing sequence. This operation can also be viewed as repacking
1079+
// of already packed vector into another packed vector.
1080+
uint64_t shlBits = 0;
1081+
if (match(v, m_Shl(m_Value(lOp), m_ConstantInt(shlBits))) && (shlBits % numBits) == 0) {
1082+
args.push_back({lOp, shlBits + prevShlBits});
1083+
continue;
1084+
}
1085+
// Detect values that fit into `numBits` bits - a single element of
1086+
// the packed vector.
1087+
KnownBits kb = computeKnownBits(v, DL);
1088+
uint32_t nonZeroBits = ~(static_cast<uint32_t>(kb.Zero.getZExtValue()));
1089+
uint32_t lsb = findFirstSet(nonZeroBits);
1090+
uint32_t msb = findLastSet(nonZeroBits);
1091+
if (msb != lsb && (msb / numBits) == (lsb / numBits)) {
1092+
uint32_t idx = (prevShlBits / numBits) + (lsb / numBits);
1093+
if (idx < packedVecSize && toPack[idx].first == nullptr) {
1094+
toPack[idx] = std::make_pair(v, alignDown(lsb, numBits));
1095+
continue;
1096+
}
1097+
}
1098+
// Detect packing of two disjoint values. This `or` operation corresponds
1099+
// to an InsertElement instruction in the final packing sequence.
1100+
if (match(v, m_Or(m_Value(lOp), m_Value(rOp)))) {
1101+
KnownBits kbL = computeKnownBits(lOp, DL);
1102+
KnownBits kbR = computeKnownBits(rOp, DL);
1103+
uint32_t nonZeroBitsL = ~(static_cast<uint32_t>(kbL.Zero.getZExtValue()));
1104+
uint32_t nonZeroBitsR = ~(static_cast<uint32_t>(kbR.Zero.getZExtValue()));
1105+
if ((nonZeroBitsL & nonZeroBitsR) == 0) {
1106+
args.push_back({lOp, prevShlBits});
1107+
args.push_back({rOp, prevShlBits});
1108+
}
1109+
continue;
1110+
}
1111+
if (std::all_of(toPack.begin(), toPack.end(), [](const auto &c) { return c.first != nullptr; })) {
1112+
break;
1113+
}
1114+
// Unsupported pattern.
1115+
return false;
1116+
}
1117+
if (std::any_of(toPack.begin(), toPack.end(), [](const auto &c) { return c.first == nullptr; })) {
1118+
return false;
1119+
}
1120+
// In the second step match the pattern that packs sign and magnitude parts
1121+
// and simple masking with `and` instruction.
1122+
for (uint32_t i = 0; i < packedVecSize; ++i) {
1123+
auto [v, lsb] = toPack[i];
1124+
Value *lOp = nullptr;
1125+
Value *rOp = nullptr;
1126+
uint64_t lMask = 0;
1127+
uint64_t rMask = 0;
1128+
// Match patterns that pack the sign and magnitude parts.
1129+
if (match(v, m_Or(m_And(m_Value(lOp), m_ConstantInt(lMask)), m_And(m_Value(rOp), m_ConstantInt(rMask)))) &&
1130+
(countPopulation(rMask) == 1 || countPopulation(lMask) == 1)) {
1131+
Value *signOp = countPopulation(rMask) == 1 ? rOp : lOp;
1132+
Value *magnOp = countPopulation(rMask) == 1 ? lOp : rOp;
1133+
uint64_t signMask = countPopulation(rMask) == 1 ? rMask : lMask;
1134+
uint64_t magnMask = countPopulation(rMask) == 1 ? lMask : rMask;
1135+
uint64_t shlBits = 0;
1136+
uint64_t shrBits = 0;
1137+
// %b = shl nsw i32 %a, 24
1138+
// %c = and i32 %b, 2130706432
1139+
// %sign = and i32 %a, -2147483648
1140+
// %e = or i32 %sign, %c
1141+
if (match(magnOp, m_Shl(m_Value(v), m_ConstantInt(shlBits))) && v == signOp && (shlBits % numBits) == 0 &&
1142+
shlBits == (i * numBits) && (cSignMask << shlBits) == signMask && (cMagnMask << shlBits) == magnMask &&
1143+
lsb == shlBits) {
1144+
toPack[i] = std::make_pair(v, 0);
1145+
continue;
1146+
}
1147+
// %b = and i32 %a, 127
1148+
// %c = lshr i32 %a, 24
1149+
// %sign = and i32 %c, 128
1150+
// %e = or i32 %sign, %b
1151+
if (match(signOp, m_LShr(m_Value(v), m_ConstantInt(shrBits))) && v == magnOp && shrBits == (32 - numBits) &&
1152+
cSignMask == signMask && cMagnMask == magnMask && lsb == 0) {
1153+
toPack[i] = std::make_pair(v, 0);
1154+
continue;
1155+
}
1156+
}
1157+
uint64_t andMask = 0;
1158+
if (match(v, m_And(m_Value(lOp), m_ConstantInt(andMask))) && (andMask & BITMASK(numBits)) == andMask && lsb == 0) {
1159+
toPack[i] = std::make_pair(lOp, 0);
1160+
continue;
1161+
}
1162+
if (lsb > 0) {
1163+
return false;
1164+
}
1165+
}
1166+
1167+
// Create the packing sequence that is matched in the PatternMatch later.
1168+
Type *elemTy = Builder->getIntNTy(numBits);
1169+
Value *packed = PoisonValue::get(IGCLLVM::FixedVectorType::get(elemTy, packedVecSize));
1170+
for (uint32_t i = 0; i < packedVecSize; ++i) {
1171+
auto [elem, lsb] = toPack[i];
1172+
IGC_ASSERT(lsb == 0);
1173+
elem = Builder->CreateTrunc(elem, elemTy);
1174+
packed = Builder->CreateInsertElement(packed, elem, Builder->getInt32(i));
1175+
}
1176+
packed = Builder->CreateBitCast(packed, inst->getType());
1177+
inst->replaceAllUsesWith(packed);
1178+
inst->eraseFromParent();
1179+
return true;
1180+
}
1181+
10031182
FunctionPass *IGC::createGenIRLowerPass() { return new GenIRLowering(); }
10041183

10051184
// Register pass to igc-opt

IGC/Compiler/CISACodeGen/PatternMatchPass.cpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2837,22 +2837,27 @@ bool CodeGenPatternMatch::MatchPack4i8(Instruction &I) {
28372837
}
28382838
return false;
28392839
};
2840-
// Lambda matches clamp(x, 0, 127) pattern.
2840+
// Lambda matches clamp(x, MIN, MAX) pattern.
28412841
// If the pattern is found `x` is returned in the `clampedVal` reference.
2842-
auto MatchClamp0_127 = [&MatchMinMaxWithImm](Value *v, Value *&clampedVal) -> bool {
2842+
auto MatchClampWithImm = [&MatchMinMaxWithImm](Value *v, Value *&clampedVal, uint32_t minVal, uint32_t maxVal) -> bool {
28432843
bool matchMin = true;
28442844
bool matchMax = false;
28452845
Value *src[2];
28462846
// Match either of:
2847-
// v = min(max(x, 0), 127)
2848-
// v = max(min(x, 127), 0)
2849-
if ((MatchMinMaxWithImm(v, 127, matchMin, src[0]) && MatchMinMaxWithImm(src[0], 0, matchMax, src[1])) ||
2850-
(MatchMinMaxWithImm(v, 0, matchMax, src[0]) && MatchMinMaxWithImm(src[0], 127, matchMin, src[1]))) {
2847+
// v = min(max(x, MIN), MAX)
2848+
// v = max(min(x, MIN), MAX)
2849+
if ((MatchMinMaxWithImm(v, maxVal, matchMin, src[0]) && MatchMinMaxWithImm(src[0], minVal, matchMax, src[1])) ||
2850+
(MatchMinMaxWithImm(v, minVal, matchMax, src[0]) && MatchMinMaxWithImm(src[0], maxVal, matchMin, src[1]))) {
28512851
clampedVal = src[1];
28522852
return true;
28532853
}
28542854
return false;
28552855
};
2856+
// Lambda matches clamp(x, 0, 127) pattern.
2857+
// If the pattern is found `x` is returned in the `clampedVal` reference.
2858+
auto MatchClamp0_127 = [&MatchClampWithImm](Value *v, Value *&clampedVal) -> bool {
2859+
return MatchClampWithImm(v, clampedVal, 0, 127);
2860+
};
28562861

28572862
EOPCODE opcodes[4] = {};
28582863
Value *sources0[4] = {};
@@ -2897,16 +2902,23 @@ bool CodeGenPatternMatch::MatchPack4i8(Instruction &I) {
28972902
}
28982903
if (elemsFound == 4) {
28992904
// Match pattern 2
2900-
// Match clamping of values to 0..127 range, e.g.:
2901-
// %x1 = max i32 %x0, 0
2902-
// %x2 = min i32 %x1, 127
29032905
for (uint32_t i = 0; i < 4; ++i) {
29042906
Value *srcToSat;
2907+
// Match clamping of values to 0..127 range, e.g.:
2908+
// %x1 = max i32 %x0, 0
2909+
// %x2 = min i32 %x1, 127
29052910
if (MatchClamp0_127(sources0[i], srcToSat)) {
29062911
opcodes[i] = llvm_max;
29072912
sources0[i] = srcToSat;
29082913
sources1[i] = ConstantInt::get(srcToSat->getType(), 0);
29092914
isSat[i] = true;
2915+
// Match clamping of values to -128..127 range, e.g.:
2916+
// %x1 = max i32 %x0, -128
2917+
// %x2 = min i32 %x1, 127
2918+
} else if (MatchClampWithImm(sources0[i], srcToSat, -128, 127)) {
2919+
opcodes[i] = llvm_fptosi;
2920+
sources0[i] = srcToSat;
2921+
isSat[i] = true;
29102922
}
29112923
}
29122924
}

0 commit comments

Comments
 (0)