@@ -5092,9 +5092,29 @@ ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5092
5092
}
5093
5093
5094
5094
if (!AR->hasNoUnsignedWrap()) {
5095
- ConstantRange AddRecRange = getUnsignedRange(AR);
5096
- ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5097
-
5095
+ const SCEVAddRecExpr *NewAR = AR;
5096
+ unsigned BitWidth = getTypeSizeInBits(AR->getType());
5097
+ // For integer AddRecs, try to evaluate the AddRec in a type one bit wider
5098
+ // than the original type, to be able to differentiate between the AddRec
5099
+ // hitting the full range or wrapping.
5100
+ const SCEV *Step = AR->getStepRecurrence(*this);
5101
+ if (AR->getType()->isIntegerTy() && isKnownNonNegative(Step)) {
5102
+ Type *WiderTy = IntegerType::get(getContext(), BitWidth + 1);
5103
+ NewAR = cast<SCEVAddRecExpr>(
5104
+ getAddRecExpr(getSignExtendExpr(AR->getStart(), WiderTy),
5105
+ getZeroExtendExpr(Step, WiderTy), AR->getLoop(),
5106
+ AR->getNoWrapFlags()));
5107
+ ConstantRange AddRecRange = getUnsignedRange(NewAR);
5108
+ // If the wider AddRec range matches the original range after stripping
5109
+ // the top bit, the original AddRec does not self-wrap.
5110
+ if (AddRecRange !=
5111
+ AddRecRange.truncate(BitWidth).zeroExtend(BitWidth + 1))
5112
+ NewAR = AR;
5113
+ }
5114
+ ConstantRange AddRecRange = getUnsignedRange(NewAR);
5115
+ ConstantRange IncRange = getUnsignedRange(Step);
5116
+ if (NewAR != AR)
5117
+ IncRange = IncRange.zeroExtend(getTypeSizeInBits(NewAR->getType()));
5098
5118
auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5099
5119
Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5100
5120
if (NUWRegion.contains(AddRecRange))
@@ -8336,7 +8356,13 @@ const SCEV *ScalarEvolution::getPredicatedExitCount(
8336
8356
8337
8357
// Returns the exact backedge-taken count of loop L as computed from
// getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds). Preds is the
// caller-owned vector passed to getExact() to receive any predicates.
//
// Diff summary for this hunk: the removed ('-') line returned the getExact()
// result directly. The added ('+') lines keep the same return value but first
// perform cache maintenance: when the result is not SCEVCouldNotCompute and
// Preds is empty, the entry for L in BackedgeTakenCounts (if one exists) is
// erased when its isComplete() returns false.
//
// NOTE(review): presumably the erase forces a later non-predicated query to
// recompute the count instead of reusing an incomplete cached entry -- confirm
// against BackedgeTakenInfo::isComplete() and getBackedgeTakenInfo().
// NOTE(review): Preds.empty() inspects the caller's vector, so a caller that
// passes in a vector which already holds predicates skips the erase -- verify
// this is intended.
const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
8338
8358
const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8339
- return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8359
+ auto *Res = getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8360
+ if (!isa<SCEVCouldNotCompute>(Res) && Preds.empty()) {
8361
+ auto I = BackedgeTakenCounts.find(L);
8362
+ if (I != BackedgeTakenCounts.end() && !I->second.isComplete())
8363
+ BackedgeTakenCounts.erase(I);
8364
+ }
8365
+ return Res;
8340
8366
}
8341
8367
8342
8368
const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
@@ -13858,7 +13884,8 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13858
13884
SmallVector<const SCEVPredicate *, 4> Preds;
13859
13885
auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13860
13886
if (PBT != BTC) {
13861
- assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13887
+ assert((!Preds.empty() || PBT == SE->getBackedgeTakenCount(L)) &&
13888
+ "Different predicated BTC, but no predicates");
13862
13889
OS << "Loop ";
13863
13890
L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13864
13891
OS << ": ";
@@ -13877,7 +13904,8 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13877
13904
auto *PredConstantMax =
13878
13905
SE->getPredicatedConstantMaxBackedgeTakenCount(L, Preds);
13879
13906
if (PredConstantMax != ConstantBTC) {
13880
- assert(!Preds.empty() &&
13907
+ assert((!Preds.empty() ||
13908
+ PredConstantMax == SE->getConstantMaxBackedgeTakenCount(L)) &&
13881
13909
"different predicated constant max BTC but no predicates");
13882
13910
OS << "Loop ";
13883
13911
L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
@@ -13897,7 +13925,8 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13897
13925
auto *PredSymbolicMax =
13898
13926
SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
13899
13927
if (SymbolicBTC != PredSymbolicMax) {
13900
- assert(!Preds.empty() &&
13928
+ assert((!Preds.empty() ||
13929
+ PredSymbolicMax == SE->getSymbolicMaxBackedgeTakenCount(L)) &&
13901
13930
"Different predicated symbolic max BTC, but no predicates");
13902
13931
OS << "Loop ";
13903
13932
L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
0 commit comments