Skip to content

Commit a6f59ee

Browse files
committed
[SCEV] Check if AddRec doesn't wrap via BTC before adding predicate.
#131281 exposed a case where SCEV is not able to infer NSW for an AddRec, but constant folding in SCEVExpander is able to determine that the runtime wrap check is always false (i.e. no signed wrap occurs). This is caught by an assertion in LV, where we expand a runtime check and the trip count expression, but the runtime check gets folded away. With this patch, for AddRecs with a step of 1, if Start + BTC >= Start, the AddRec is treated as having NUW/NSW and no wrap predicate is added. https://alive2.llvm.org/ce/z/VnWwEN This check can help determine NSW/NUW in a few more cases, but doing so for all AddRecs has a noticeable compile-time impact: https://llvm-compile-time-tracker.com/compare.php?from=215c0d2b651dc757378209a3edaff1a130338dd8&to=cdd1c1d32c598d77b73a57bcc05c1383786b3ac4&stat=instructions:u I am not sure whether there is a good general place where we could try to refine wrap flags in SCEV with logic like that in this patch. Fixes #131281.
1 parent 0b47f6b commit a6f59ee

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14775,6 +14775,29 @@ const SCEVPredicate *ScalarEvolution::getWrapPredicate(
1477514775

1477614776
namespace {
1477714777

14778+
/// Return true if \p AR is known to not wrap in the sense of \p Pred (which
14779+
/// must be IncrementNSSW or IncrementNUSW), using the backedge-taken count of
/// the loop \p AR belongs to. \p SE is used to compute that count and to
/// evaluate the final comparison.
14780+
static bool proveNoWrapViaBTC(const SCEVAddRecExpr *AR,
14781+
SCEVWrapPredicate::IncrementWrapFlags Pred,
14782+
ScalarEvolution &SE) {
14783+
const Loop *L = AR->getLoop();
14784+
const SCEV *BTC = SE.getBackedgeTakenCount(L);
14785+
// Without a computable backedge-taken count there is nothing to reason with.
if (isa<SCEVCouldNotCompute>(BTC))
14786+
return false;
14787+
// Only handle AddRecs with a step of 1 whose type matches the type of BTC,
// so that Start + BTC below can be formed directly.
if (!match(AR->getStepRecurrence(SE), m_scev_One()) ||
14788+
AR->getType() != BTC->getType())
14789+
return false;
14790+
// AR has a step of 1, so it is NSSW/NUSW if Start + BTC >= Start, using a
// signed comparison for NSSW and an unsigned comparison for NUSW.
14791+
auto *Add = SE.getAddExpr(AR->getStart(), BTC);
14792+
assert((Pred == SCEVWrapPredicate::IncrementNSSW ||
14793+
Pred == SCEVWrapPredicate::IncrementNUSW) &&
14794+
"Unexpected predicate");
14795+
return SE.isKnownPredicate(Pred == SCEVWrapPredicate::IncrementNSSW
14796+
? CmpInst::ICMP_SGE
14797+
: CmpInst::ICMP_UGE,
14798+
Add, AR->getStart());
14799+
}
14800+
1477814801
class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
1477914802
public:
1478014803

@@ -14860,6 +14883,8 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
1486014883

1486114884
/// Try to assume the no-wrap flags \p AddedFlags for \p AR. Returns true
/// right away when proveNoWrapViaBTC already establishes them; otherwise
/// returns the result of the overload that takes the wrap predicate.
bool addOverflowAssumption(const SCEVAddRecExpr *AR,
1486214885
SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14886+
// If the backedge-taken count already proves that AR does not wrap, no
// wrap predicate needs to be created for it.
if (proveNoWrapViaBTC(AR, AddedFlags, SE))
14887+
return true;
1486314888
// Otherwise build the wrap predicate and delegate to the overload that
// takes a predicate.
auto *A = SE.getWrapPredicate(AR, AddedFlags);
1486414889
return addOverflowAssumption(A);
1486514890
}

llvm/test/Transforms/LoopVectorize/scev-predicate-reasoning.ll

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,58 @@ declare i1 @cond()
247247
; Test case for https://github.com/llvm/llvm-project/issues/131281.
248248
; %add2 is known to not wrap via BTC.
249249
define void @no_signed_wrap_iv_via_btc(ptr %dst, i32 %N) mustprogress {
250+
; CHECK-LABEL: define void @no_signed_wrap_iv_via_btc
251+
; CHECK-SAME: (ptr [[DST:%.*]], i32 [[N:%.*]]) #[[ATTR0:[0-9]+]] {
252+
; CHECK-NEXT: entry:
253+
; CHECK-NEXT: [[SUB:%.*]] = add i32 [[N]], -100
254+
; CHECK-NEXT: [[SUB4:%.*]] = add i32 [[N]], -99
255+
; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[N]], 1
256+
; CHECK-NEXT: [[SMAX:%.*]] = call i32 @llvm.smax.i32(i32 [[SUB4]], i32 [[TMP0]])
257+
; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[SMAX]], 100
258+
; CHECK-NEXT: [[TMP2:%.*]] = sub i32 [[TMP1]], [[N]]
259+
; CHECK-NEXT: br label [[OUTER:%.*]]
260+
; CHECK: outer.loopexit:
261+
; CHECK-NEXT: br label [[OUTER]]
262+
; CHECK: outer:
263+
; CHECK-NEXT: [[C:%.*]] = call i1 @cond()
264+
; CHECK-NEXT: br i1 [[C]], label [[LOOP_PREHEADER:%.*]], label [[EXIT:%.*]]
265+
; CHECK: loop.preheader:
266+
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[TMP2]], 4
267+
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
268+
; CHECK: vector.ph:
269+
; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i32 [[TMP2]], 4
270+
; CHECK-NEXT: [[N_VEC:%.*]] = sub i32 [[TMP2]], [[N_MOD_VF]]
271+
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
272+
; CHECK: vector.body:
273+
; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
274+
; CHECK-NEXT: [[TMP3:%.*]] = add i32 [[INDEX]], 0
275+
; CHECK-NEXT: [[TMP4:%.*]] = add i32 [[SUB4]], [[TMP3]]
276+
; CHECK-NEXT: [[TMP5:%.*]] = sext i32 [[TMP4]] to i64
277+
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i32, ptr [[DST]], i64 [[TMP5]]
278+
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i32, ptr [[TMP6]], i32 0
279+
; CHECK-NEXT: store <4 x i32> zeroinitializer, ptr [[TMP7]], align 4
280+
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
281+
; CHECK-NEXT: [[TMP8:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
282+
; CHECK-NEXT: br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
283+
; CHECK: middle.block:
284+
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 [[TMP2]], [[N_VEC]]
285+
; CHECK-NEXT: br i1 [[CMP_N]], label [[OUTER_LOOPEXIT:%.*]], label [[SCALAR_PH]]
286+
; CHECK: scalar.ph:
287+
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[LOOP_PREHEADER]] ]
288+
; CHECK-NEXT: br label [[LOOP:%.*]]
289+
; CHECK: loop:
290+
; CHECK-NEXT: [[IV:%.*]] = phi i32 [ [[INC:%.*]], [[LOOP]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
291+
; CHECK-NEXT: [[ADD2:%.*]] = add i32 [[SUB4]], [[IV]]
292+
; CHECK-NEXT: [[ADD_EXT:%.*]] = sext i32 [[ADD2]] to i64
293+
; CHECK-NEXT: [[GEP_DST:%.*]] = getelementptr i32, ptr [[DST]], i64 [[ADD_EXT]]
294+
; CHECK-NEXT: store i32 0, ptr [[GEP_DST]], align 4
295+
; CHECK-NEXT: [[INC]] = add i32 [[IV]], 1
296+
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[SUB]], [[INC]]
297+
; CHECK-NEXT: [[EC:%.*]] = icmp sgt i32 [[ADD]], [[N]]
298+
; CHECK-NEXT: br i1 [[EC]], label [[OUTER_LOOPEXIT]], label [[LOOP]], !llvm.loop [[LOOP9:![0-9]+]]
299+
; CHECK: exit:
300+
; CHECK-NEXT: ret void
301+
;
250302
entry:
251303
%sub = add i32 %N, -100
252304
%sub4 = add i32 %N, -99

0 commit comments

Comments
 (0)