Skip to content

Commit 6d2e0ce

Browse files
kyansito authored and igcbot committed
fix width of constant vectorization
allow choosing vector width depending on GRF size
1 parent 9907e77 commit 6d2e0ce

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

IGC/VectorCompiler/lib/GenXCodeGen/GenXPatternMatch.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3311,11 +3311,17 @@ bool GenXPatternMatch::distributeIntegerMul(Function *F) {
33113311
// where ShtAmt[0] is a constant vector and ShtAmt[i] are constant splats.
33123312
static bool analyzeForShiftPattern(Constant *C,
33133313
SmallVectorImpl<Constant *> &ShtAmt,
3314-
const DataLayout &DL) {
3315-
unsigned Width = 8;
3314+
const DataLayout &DL,
3315+
const llvm::GenXSubtarget &Subtarget) {
33163316
auto *VT = dyn_cast<IGCLLVM::FixedVectorType>(C->getType());
3317-
if (!VT || VT->getNumElements() <= Width || VT->getScalarSizeInBits() == 1)
3317+
if (!VT || VT->getScalarSizeInBits() == 1)
33183318
return false;
3319+
3320+
unsigned ElmSz = VT->getScalarSizeInBits() / genx::ByteBits;
3321+
unsigned Width = Subtarget.getGRFByteSize() / ElmSz;
3322+
if (cast<IGCLLVM::FixedVectorType>(VT)->getNumElements() <= Width)
3323+
return false;
3324+
33193325
unsigned NElts = VT->getNumElements();
33203326
if (NElts % Width != 0)
33213327
return false;
@@ -3388,6 +3394,9 @@ static bool analyzeForShiftPattern(Constant *C,
33883394
}
33893395

33903396
bool GenXPatternMatch::vectorizeConstants(Function *F) {
3397+
const GenXSubtarget *ST = &getAnalysis<TargetPassConfig>()
3398+
.getTM<GenXTargetMachine>()
3399+
.getGenXSubtarget();
33913400
bool Changed = false;
33923401
for (auto &BB : F->getBasicBlockList()) {
33933402
for (auto I = BB.begin(); I != BB.end();) {
@@ -3410,7 +3419,7 @@ bool GenXPatternMatch::vectorizeConstants(Function *F) {
34103419
C->getSplatValue())
34113420
continue;
34123421
SmallVector<Constant *, 8> ShtAmt;
3413-
if (analyzeForShiftPattern(C, ShtAmt, *DL)) {
3422+
if (analyzeForShiftPattern(C, ShtAmt, *DL, *ST)) {
34143423
// W1 = wrregion(undef, ShtAmt[0], 0);
34153424
// V2 = fadd ShtAmt[0], ShtAmt[1]
34163425
// W2 = wrregion(W1, V2, Width)

0 commit comments

Comments
 (0)