@@ -21,7 +21,9 @@ SPDX-License-Identifier: MIT
21
21
#include < llvm/Analysis/ScalarEvolution.h>
22
22
#include < llvm/Analysis/ScalarEvolutionExpressions.h>
23
23
#include < llvm/Analysis/TargetFolder.h>
24
+ #include < llvm/Analysis/ValueTracking.h>
24
25
#include < llvm/IR/GetElementPtrTypeIterator.h>
26
+ #include < llvm/Support/KnownBits.h>
25
27
#include < llvm/Transforms/Utils/ScalarEvolutionExpander.h>
26
28
#include < llvm/Transforms/Utils/Local.h>
27
29
#include " llvmWrapper/IR/Intrinsics.h"
@@ -63,6 +65,7 @@ class GenIRLowering : public FunctionPass {
63
65
64
66
bool combineFMaxFMin (CallInst *GII, BasicBlock::iterator &BBI) const ;
65
67
bool combineSelectInst (SelectInst *Sel, BasicBlock::iterator &BBI) const ;
68
+ bool combinePack4i8Or2i16 (Instruction *inst, uint64_t numBits) const ;
66
69
67
70
bool constantFoldFMaxFMin (CallInst *GII, BasicBlock::iterator &BBI) const ;
68
71
};
@@ -362,6 +365,15 @@ bool GenIRLowering::runOnFunction(Function &F) {
362
365
Changed |= combineSelectInst (cast<SelectInst>(Inst), BI);
363
366
}
364
367
break ;
368
+ case Instruction::Or:
369
+ if (Inst->getType ()->isIntegerTy (32 )) {
370
+ // Detect packing of 4 i8 values and convert to a pattern that is
371
+ // matched CodeGenPatternMatch::MatchPack4i8
372
+ Changed |= combinePack4i8Or2i16 (Inst, 8 /* numBits*/ );
373
+ // TODO: also detect <2 x i16> packing once PatternMatch is updated
374
+ // to packing of 16-bit values.
375
+ }
376
+ break ;
365
377
}
366
378
}
367
379
}
@@ -1000,6 +1012,173 @@ bool GenIRLowering::combineSelectInst(SelectInst *Sel, BasicBlock::iterator &BBI
1000
1012
return false ;
1001
1013
}
1002
1014
1015
+ // //////////////////////////////////////////////////////////////////////////////
1016
+ // Detect complex patterns that pack 2 16-bit or 4 8-bit integers into a 32-bit
1017
+ // value. Generate equivalent sequence of instructions that is later matched in
1018
+ // the CodeGenPatternMatch::MatchPack4i8().
1019
+ // Pattern example for <4 x i8> packing:
1020
+ // %x1 = and i32 %x, 127
1021
+ // %x2 = lshr i32 %x, 24
1022
+ // %x3 = and i32 %x2, 128
1023
+ // %x4 = or i32 %x3, %x1
1024
+ // %y1 = and i32 %y, 127
1025
+ // %y2 = lshr i32 %y, 24
1026
+ // %y3 = and i32 %y2, 128
1027
+ // %y4 = or i32 %y3, %y1
1028
+ // %y5 = shl nuw nsw i32 %y4, 8
1029
+ // %xy = or i32 %x4, %y5
1030
+ // %z1 = and i32 %z, 127
1031
+ // %z2 = lshr i32 %z, 24
1032
+ // %z3 = and i32 %z2, 128
1033
+ // %z4 = or i32 %z3, %z1
1034
+ // %z5 = shl nuw nsw i32 %z4, 16
1035
+ // %xyz = or i32 %xy, %z5
1036
+ // %w1 = shl nsw i32 %w, 24
1037
+ // %w2 = and i32 %w1, 2130706432
1038
+ // %w3 = and i32 %w, -2147483648
1039
+ // %w4 = or i32 %w2, %w3
1040
+ // %xyzw = or i32 %xyz, %w4
1041
+ // and generate:
1042
+ // %0 = trunc i32 %x to i8
1043
+ // %1 = insertelement <4 x i8> poison, i8 %0, i32 0
1044
+ // %2 = trunc i32 %y to i8
1045
+ // %3 = insertelement <4 x i8> %1, i8 %2, i32 1
1046
+ // %4 = trunc i32 %z to i8
1047
+ // %5 = insertelement <4 x i8> %3, i8 %4, i32 2
1048
+ // %6 = trunc i32 %w to i8
1049
+ // %7 = insertelement <4 x i8> %5, i8 %6, i32 3
1050
+ // %8 = bitcast <4 x i8> %7 to i32
1051
+ bool GenIRLowering::combinePack4i8Or2i16 (Instruction *inst, uint64_t numBits) const {
1052
+ using namespace llvm ::PatternMatch;
1053
+
1054
+ const DataLayout &DL = inst->getModule ()->getDataLayout ();
1055
+ // Vector of 4 or 2 values that will be packed into a single 32-bit value.
1056
+ // The std::pair contains the 32-bit value that contains the element
1057
+ // to pack and the LSB where the packed value starts in the 32-bit value.
1058
+ SmallVector<std::pair<Value *, uint64_t >, 4 > toPack;
1059
+ IGC_ASSERT (numBits == 8 || numBits == 16 );
1060
+ uint64_t packedVecSize = 32 / numBits;
1061
+ toPack.resize (packedVecSize);
1062
+ uint64_t cSignMask = QWBIT (numBits - 1 );
1063
+ uint64_t cMagnMask = BITMASK (numBits - 1 );
1064
+ // The std::pair contains the 32-bit value that contains the element
1065
+ // to pack and the left shift bits that indicate the element position
1066
+ // in the packed vector.
1067
+ SmallVector<std::pair<Value *, uint64_t >, 4 > args;
1068
+ args.push_back ({isa<BitCastInst>(inst) ? inst->getOperand (0 ) : inst, 0 });
1069
+ // In the first step traverse the chain of `or` and `shl` instructions
1070
+ // and find all elements of the packed vector.
1071
+ while (!args.empty ()) {
1072
+ auto [v, prevShlBits] = args.pop_back_val ();
1073
+ Value *lOp = nullptr ;
1074
+ Value *rOp = nullptr ;
1075
+
1076
+ // Detect left shift by multiple of `numBits`. The `shl` operation sets the
1077
+ // `index` argument in the corresponding InsertElement instruction in the
1078
+ // final packing sequence. This operation can also be viewed as repacking
1079
+ // of already packed vector into another packed vector.
1080
+ uint64_t shlBits = 0 ;
1081
+ if (match (v, m_Shl (m_Value (lOp), m_ConstantInt (shlBits))) && (shlBits % numBits) == 0 ) {
1082
+ args.push_back ({lOp, shlBits + prevShlBits});
1083
+ continue ;
1084
+ }
1085
+ // Detect values that fit into `numBits` bits - a single element of
1086
+ // the packed vector.
1087
+ KnownBits kb = computeKnownBits (v, DL);
1088
+ uint32_t nonZeroBits = ~(static_cast <uint32_t >(kb.Zero .getZExtValue ()));
1089
+ uint32_t lsb = findFirstSet (nonZeroBits);
1090
+ uint32_t msb = findLastSet (nonZeroBits);
1091
+ if (msb != lsb && (msb / numBits) == (lsb / numBits)) {
1092
+ uint32_t idx = (prevShlBits / numBits) + (lsb / numBits);
1093
+ if (idx < packedVecSize && toPack[idx].first == nullptr ) {
1094
+ toPack[idx] = std::make_pair (v, alignDown (lsb, numBits));
1095
+ continue ;
1096
+ }
1097
+ }
1098
+ // Detect packing of two disjoint values. This `or` operation corresponds
1099
+ // to an InsertElement instruction in the final packing sequence.
1100
+ if (match (v, m_Or (m_Value (lOp), m_Value (rOp)))) {
1101
+ KnownBits kbL = computeKnownBits (lOp, DL);
1102
+ KnownBits kbR = computeKnownBits (rOp, DL);
1103
+ uint32_t nonZeroBitsL = ~(static_cast <uint32_t >(kbL.Zero .getZExtValue ()));
1104
+ uint32_t nonZeroBitsR = ~(static_cast <uint32_t >(kbR.Zero .getZExtValue ()));
1105
+ if ((nonZeroBitsL & nonZeroBitsR) == 0 ) {
1106
+ args.push_back ({lOp, prevShlBits});
1107
+ args.push_back ({rOp, prevShlBits});
1108
+ }
1109
+ continue ;
1110
+ }
1111
+ if (std::all_of (toPack.begin (), toPack.end (), [](const auto &c) { return c.first != nullptr ; })) {
1112
+ break ;
1113
+ }
1114
+ // Unsupported pattern.
1115
+ return false ;
1116
+ }
1117
+ if (std::any_of (toPack.begin (), toPack.end (), [](const auto &c) { return c.first == nullptr ; })) {
1118
+ return false ;
1119
+ }
1120
+ // In the second step match the pattern that packs sign and magnitude parts
1121
+ // and simple masking with `and` instruction.
1122
+ for (uint32_t i = 0 ; i < packedVecSize; ++i) {
1123
+ auto [v, lsb] = toPack[i];
1124
+ Value *lOp = nullptr ;
1125
+ Value *rOp = nullptr ;
1126
+ uint64_t lMask = 0 ;
1127
+ uint64_t rMask = 0 ;
1128
+ // Match patterns that pack the sign and magnitude parts.
1129
+ if (match (v, m_Or (m_And (m_Value (lOp), m_ConstantInt (lMask)), m_And (m_Value (rOp), m_ConstantInt (rMask)))) &&
1130
+ (countPopulation (rMask) == 1 || countPopulation (lMask) == 1 )) {
1131
+ Value *signOp = countPopulation (rMask) == 1 ? rOp : lOp;
1132
+ Value *magnOp = countPopulation (rMask) == 1 ? lOp : rOp;
1133
+ uint64_t signMask = countPopulation (rMask) == 1 ? rMask : lMask;
1134
+ uint64_t magnMask = countPopulation (rMask) == 1 ? lMask : rMask;
1135
+ uint64_t shlBits = 0 ;
1136
+ uint64_t shrBits = 0 ;
1137
+ // %b = shl nsw i32 %a, 24
1138
+ // %c = and i32 %b, 2130706432
1139
+ // %sign = and i32 %a, -2147483648
1140
+ // %e = or i32 %sign, %c
1141
+ if (match (magnOp, m_Shl (m_Value (v), m_ConstantInt (shlBits))) && v == signOp && (shlBits % numBits) == 0 &&
1142
+ shlBits == (i * numBits) && (cSignMask << shlBits) == signMask && (cMagnMask << shlBits) == magnMask &&
1143
+ lsb == shlBits) {
1144
+ toPack[i] = std::make_pair (v, 0 );
1145
+ continue ;
1146
+ }
1147
+ // %b = and i32 %a, 127
1148
+ // %c = lshr i32 %a, 24
1149
+ // %sign = and i32 %c, 128
1150
+ // %e = or i32 %sign, %b
1151
+ if (match (signOp, m_LShr (m_Value (v), m_ConstantInt (shrBits))) && v == magnOp && shrBits == (32 - numBits) &&
1152
+ cSignMask == signMask && cMagnMask == magnMask && lsb == 0 ) {
1153
+ toPack[i] = std::make_pair (v, 0 );
1154
+ continue ;
1155
+ }
1156
+ }
1157
+ uint64_t andMask = 0 ;
1158
+ if (match (v, m_And (m_Value (lOp), m_ConstantInt (andMask))) && (andMask & BITMASK (numBits)) == andMask && lsb == 0 ) {
1159
+ toPack[i] = std::make_pair (lOp, 0 );
1160
+ continue ;
1161
+ }
1162
+ if (lsb > 0 ) {
1163
+ return false ;
1164
+ }
1165
+ }
1166
+
1167
+ // Create the packing sequence that is matched in the PatternMatch later.
1168
+ Type *elemTy = Builder->getIntNTy (numBits);
1169
+ Value *packed = PoisonValue::get (IGCLLVM::FixedVectorType::get (elemTy, packedVecSize));
1170
+ for (uint32_t i = 0 ; i < packedVecSize; ++i) {
1171
+ auto [elem, lsb] = toPack[i];
1172
+ IGC_ASSERT (lsb == 0 );
1173
+ elem = Builder->CreateTrunc (elem, elemTy);
1174
+ packed = Builder->CreateInsertElement (packed, elem, Builder->getInt32 (i));
1175
+ }
1176
+ packed = Builder->CreateBitCast (packed, inst->getType ());
1177
+ inst->replaceAllUsesWith (packed);
1178
+ inst->eraseFromParent ();
1179
+ return true ;
1180
+ }
1181
+
1003
1182
FunctionPass *IGC::createGenIRLowerPass () { return new GenIRLowering (); }
1004
1183
1005
1184
// Register pass to igc-opt
0 commit comments