Skip to content

Commit 7eb3d72

Browse files
committed
[SCEV] Distinguish between full and wrapping AddRec in proveNoWrapViaCR.
Try to widen integer AddRecs to a type one bit wider to distinguish between the AddRec wrapping or just hitting all possible values. Alternative to llvm#131538. Note that now we can end up in the awkward situation that we fail to compute an unpredicated BTC on the first try, but succeed on the second try, because we now have a accurate max BTC. For now, I updated getPredicatedBackedgeTakencCount to remove cached BTC if that happens, but perhaps there's a better solution?
1 parent b30d531 commit 7eb3d72

File tree

6 files changed

+102
-81
lines changed

6 files changed

+102
-81
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5092,9 +5092,29 @@ ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
50925092
}
50935093

50945094
if (!AR->hasNoUnsignedWrap()) {
5095-
ConstantRange AddRecRange = getUnsignedRange(AR);
5096-
ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5097-
5095+
const SCEVAddRecExpr *NewAR = AR;
5096+
unsigned BitWidth = getTypeSizeInBits(AR->getType());
5097+
// For integer AddRecs, try to evaluate the AddRec in a type one bit wider
5098+
// than the original type, to be able to differentiate between the AddRec
5099+
// hitting the full range or wrapping.
5100+
const SCEV *Step = AR->getStepRecurrence(*this);
5101+
if (AR->getType()->isIntegerTy() && isKnownNonNegative(Step)) {
5102+
Type *WiderTy = IntegerType::get(getContext(), BitWidth + 1);
5103+
NewAR = cast<SCEVAddRecExpr>(
5104+
getAddRecExpr(getSignExtendExpr(AR->getStart(), WiderTy),
5105+
getZeroExtendExpr(Step, WiderTy), AR->getLoop(),
5106+
AR->getNoWrapFlags()));
5107+
ConstantRange AddRecRange = getUnsignedRange(NewAR);
5108+
// If the wider AddRec range matches the original range after stripping
5109+
// the top bit, the original AddRec does not self-wrap.
5110+
if (AddRecRange !=
5111+
AddRecRange.truncate(BitWidth).zeroExtend(BitWidth + 1))
5112+
NewAR = AR;
5113+
}
5114+
ConstantRange AddRecRange = getUnsignedRange(NewAR);
5115+
ConstantRange IncRange = getUnsignedRange(Step);
5116+
if (NewAR != AR)
5117+
IncRange = IncRange.zeroExtend(getTypeSizeInBits(NewAR->getType()));
50985118
auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
50995119
Instruction::Add, IncRange, OBO::NoUnsignedWrap);
51005120
if (NUWRegion.contains(AddRecRange))
@@ -8336,7 +8356,13 @@ const SCEV *ScalarEvolution::getPredicatedExitCount(
83368356

83378357
const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
83388358
const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8339-
return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8359+
auto *Res = getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8360+
if (!isa<SCEVCouldNotCompute>(Res) && Preds.empty()) {
8361+
auto I = BackedgeTakenCounts.find(L);
8362+
if (I != BackedgeTakenCounts.end() && !I->second.isComplete())
8363+
BackedgeTakenCounts.erase(I);
8364+
}
8365+
return Res;
83408366
}
83418367

83428368
const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
@@ -13858,7 +13884,8 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
1385813884
SmallVector<const SCEVPredicate *, 4> Preds;
1385913885
auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
1386013886
if (PBT != BTC) {
13861-
assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13887+
assert((!Preds.empty() || PBT == SE->getBackedgeTakenCount(L)) &&
13888+
"Different predicated BTC, but no predicates");
1386213889
OS << "Loop ";
1386313890
L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
1386413891
OS << ": ";
@@ -13877,7 +13904,8 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
1387713904
auto *PredConstantMax =
1387813905
SE->getPredicatedConstantMaxBackedgeTakenCount(L, Preds);
1387913906
if (PredConstantMax != ConstantBTC) {
13880-
assert(!Preds.empty() &&
13907+
assert((!Preds.empty() ||
13908+
PredConstantMax == SE->getConstantMaxBackedgeTakenCount(L)) &&
1388113909
"different predicated constant max BTC but no predicates");
1388213910
OS << "Loop ";
1388313911
L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
@@ -13897,7 +13925,8 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
1389713925
auto *PredSymbolicMax =
1389813926
SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
1389913927
if (SymbolicBTC != PredSymbolicMax) {
13900-
assert(!Preds.empty() &&
13928+
assert((!Preds.empty() ||
13929+
PredSymbolicMax == SE->getSymbolicMaxBackedgeTakenCount(L)) &&
1390113930
"Different predicated symbolic max BTC, but no predicates");
1390213931
OS << "Loop ";
1390313932
L->getHeader()->printAsOperand(OS, /*PrintType=*/false);

llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,26 +153,22 @@ define void @ule_from_zero_no_nuw(i32 %M, i32 %N) {
153153
; CHECK-NEXT: exit count for loop: ***COULDNOTCOMPUTE***
154154
; CHECK-NEXT: predicated exit count for loop: (1 + (zext i32 %M to i64))<nuw><nsw>
155155
; CHECK-NEXT: Predicates:
156-
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
157156
; CHECK-EMPTY:
158157
; CHECK-NEXT: exit count for latch: %N
159158
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i32 -1
160159
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is %N
161160
; CHECK-NEXT: symbolic max exit count for loop: ***COULDNOTCOMPUTE***
162161
; CHECK-NEXT: predicated symbolic max exit count for loop: (1 + (zext i32 %M to i64))<nuw><nsw>
163162
; CHECK-NEXT: Predicates:
164-
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
165163
; CHECK-EMPTY:
166164
; CHECK-NEXT: symbolic max exit count for latch: %N
167165
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((zext i32 %N to i64) umin (1 + (zext i32 %M to i64))<nuw><nsw>)
168166
; CHECK-NEXT: Predicates:
169-
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
170167
; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4294967295
171168
; CHECK-NEXT: Predicates:
172-
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
173169
; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((zext i32 %N to i64) umin (1 + (zext i32 %M to i64))<nuw><nsw>)
174170
; CHECK-NEXT: Predicates:
175-
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
171+
; CHECK-NEXT: Loop %loop: Trip multiple is 1
176172
;
177173
entry:
178174
br label %loop
@@ -198,26 +194,22 @@ define void @le_from_zero_no_nuw(i32 %M, i32 %N) {
198194
; CHECK-NEXT: exit count for loop: ***COULDNOTCOMPUTE***
199195
; CHECK-NEXT: predicated exit count for loop: (1 + (zext i32 %M to i64))<nuw><nsw>
200196
; CHECK-NEXT: Predicates:
201-
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
202197
; CHECK-EMPTY:
203198
; CHECK-NEXT: exit count for latch: %N
204199
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i32 -1
205200
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is %N
206201
; CHECK-NEXT: symbolic max exit count for loop: ***COULDNOTCOMPUTE***
207202
; CHECK-NEXT: predicated symbolic max exit count for loop: (1 + (zext i32 %M to i64))<nuw><nsw>
208203
; CHECK-NEXT: Predicates:
209-
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
210204
; CHECK-EMPTY:
211205
; CHECK-NEXT: symbolic max exit count for latch: %N
212206
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((zext i32 %N to i64) umin (1 + (zext i32 %M to i64))<nuw><nsw>)
213207
; CHECK-NEXT: Predicates:
214-
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
215208
; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4294967295
216209
; CHECK-NEXT: Predicates:
217-
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
218210
; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((zext i32 %N to i64) umin (1 + (zext i32 %M to i64))<nuw><nsw>)
219211
; CHECK-NEXT: Predicates:
220-
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
212+
; CHECK-NEXT: Loop %loop: Trip multiple is 1
221213
;
222214
entry:
223215
br label %loop

0 commit comments

Comments
 (0)