Skip to content

Commit fe453aa

Browse files
fhahn authored and github-actions[bot] committed
Automerge: [VPlan] Keep common flags during CSE. (#157664)
During CSE, we don't have to drop all poison-generating flags on a mismatch; we can keep the ones that are common to both recipes. PR: llvm/llvm-project#157664
2 parents ba97e33 + c3e76b2 commit fe453aa

File tree

6 files changed

+516
-8
lines changed

6 files changed

+516
-8
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,10 @@ class VPIRFlags {
721721
AllFlags = Other.AllFlags;
722722
}
723723

724+
/// Only keep flags also present in \p Other. \p Other must have the same
725+
/// OpType as the current object.
726+
void intersectFlags(const VPIRFlags &Other);
727+
724728
/// Drop all poison-generating flags.
725729
void dropPoisonGeneratingFlags() {
726730
// NOTE: This needs to be kept in-sync with

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,42 @@ void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
392392
}
393393
#endif
394394

395+
// Narrow the flags stored on this object to those that are also set on
// \p Other, dispatching on OpType to pick the flag group to update. Each
// boolean flag is AND-ed in place with its counterpart on \p Other, so a flag
// stays set only if both objects had it set. \p Other itself is not modified.
//
// Precondition (checked by assert only): both objects have the same OpType.
void VPIRFlags::intersectFlags(const VPIRFlags &Other) {
396+
assert(OpType == Other.OpType && "OpType must match");
397+
switch (OpType) {
398+
case OperationType::OverflowingBinOp:
399+
WrapFlags.HasNUW &= Other.WrapFlags.HasNUW;
400+
WrapFlags.HasNSW &= Other.WrapFlags.HasNSW;
401+
break;
402+
case OperationType::Trunc:
403+
TruncFlags.HasNUW &= Other.TruncFlags.HasNUW;
404+
TruncFlags.HasNSW &= Other.TruncFlags.HasNSW;
405+
break;
406+
case OperationType::DisjointOp:
407+
DisjointFlags.IsDisjoint &= Other.DisjointFlags.IsDisjoint;
408+
break;
409+
case OperationType::PossiblyExactOp:
410+
ExactFlags.IsExact &= Other.ExactFlags.IsExact;
411+
break;
412+
case OperationType::GEPOp:
413+
// GEP flags are combined as a whole via the flag type's own operator&=.
GEPFlags &= Other.GEPFlags;
414+
break;
415+
case OperationType::FPMathOp:
416+
// NOTE(review): only NoNaNs and NoInfs are intersected here. If FMFs carries
// further fast-math flags, they are left as they were on this object even
// when \p Other lacks them -- confirm this is intended, since the declaration
// comment promises to keep only flags also present in \p Other.
FMFs.NoNaNs &= Other.FMFs.NoNaNs;
417+
FMFs.NoInfs &= Other.FMFs.NoInfs;
418+
break;
419+
case OperationType::NonNegOp:
420+
NonNegFlags.NonNeg &= Other.NonNegFlags.NonNeg;
421+
break;
422+
case OperationType::Cmp:
423+
// A predicate cannot be partially kept, so nothing is modified; the two
// predicates are required to be equal already.
assert(CmpPredicate == Other.CmpPredicate && "Cannot drop CmpPredicate");
424+
break;
425+
case OperationType::Other:
426+
// Likewise nothing is modified: the raw flag storage must already match.
assert(AllFlags == Other.AllFlags && "Cannot drop other flags");
427+
break;
428+
}
430+
395431
FastMathFlags VPIRFlags::getFastMathFlags() const {
396432
assert(OpType == OperationType::FPMathOp &&
397433
"recipe doesn't have fast math flags");

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2043,9 +2043,9 @@ void VPlanTransforms::cse(VPlan &Plan) {
20432043
// V must dominate Def for a valid replacement.
20442044
if (!VPDT.dominates(V->getParent(), VPBB))
20452045
continue;
2046-
// Drop poison-generating flags when reusing a value.
2046+
// Only keep flags present on both V and Def.
20472047
if (auto *RFlags = dyn_cast<VPRecipeWithIRFlags>(V))
2048-
RFlags->dropPoisonGeneratingFlags();
2048+
RFlags->intersectFlags(*cast<VPRecipeWithIRFlags>(Def));
20492049
Def->replaceAllUsesWith(V);
20502050
continue;
20512051
}

llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ define dso_local void @test(ptr %Arr, i32 signext %Len) {
1616
; CHECK: vector.body:
1717
; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
1818
; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[INDEX]] to i64
19-
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i32, ptr [[ARR:%.*]], i64 [[TMP1]]
19+
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[ARR:%.*]], i64 [[TMP1]]
2020
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP2]], align 4
2121
; CHECK-NEXT: [[TMP4:%.*]] = call <4 x i32> @llvm.bswap.v4i32(<4 x i32> [[WIDE_LOAD]])
2222
; CHECK-NEXT: store <4 x i32> [[TMP4]], ptr [[TMP2]], align 4

llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ define void @_Z3fn1v() #0 {
142142
; CHECK-NEXT: [[TMP32:%.*]] = add nsw <16 x i64> [[TMP30]], [[VEC_IND37]]
143143
; CHECK-NEXT: [[TMP33:%.*]] = getelementptr inbounds [10 x i32], <16 x ptr> [[TMP31]], <16 x i64> [[TMP32]], i64 0
144144
; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 8), <16 x ptr> [[TMP33]], i32 16, <16 x i1> [[TMP34]])
145-
; CHECK-NEXT: [[TMP49:%.*]] = or <16 x i64> [[VEC_IND37]], splat (i64 1)
146-
; CHECK-NEXT: [[TMP36:%.*]] = add <16 x i64> [[TMP30]], [[TMP49]]
145+
; CHECK-NEXT: [[TMP49:%.*]] = or disjoint <16 x i64> [[VEC_IND37]], splat (i64 1)
146+
; CHECK-NEXT: [[TMP36:%.*]] = add nsw <16 x i64> [[TMP30]], [[TMP49]]
147147
; CHECK-NEXT: [[TMP37:%.*]] = getelementptr inbounds [10 x i32], <16 x ptr> [[TMP31]], <16 x i64> [[TMP36]], i64 0
148148
; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 8), <16 x ptr> [[TMP37]], i32 8, <16 x i1> [[TMP34]])
149149
; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 7), <16 x ptr> [[TMP33]], i32 16, <16 x i1> [[BROADCAST_SPLAT]])
@@ -191,8 +191,8 @@ define void @_Z3fn1v() #0 {
191191
; CHECK-NEXT: [[TMP46:%.*]] = add nsw <8 x i64> [[TMP44]], [[VEC_IND70]]
192192
; CHECK-NEXT: [[TMP47:%.*]] = getelementptr inbounds [10 x i32], <8 x ptr> [[TMP45]], <8 x i64> [[TMP46]], i64 0
193193
; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 8), <8 x ptr> [[TMP47]], i32 16, <8 x i1> [[TMP48]])
194-
; CHECK-NEXT: [[TMP54:%.*]] = or <8 x i64> [[VEC_IND70]], splat (i64 1)
195-
; CHECK-NEXT: [[TMP50:%.*]] = add <8 x i64> [[TMP44]], [[TMP54]]
194+
; CHECK-NEXT: [[TMP54:%.*]] = or disjoint <8 x i64> [[VEC_IND70]], splat (i64 1)
195+
; CHECK-NEXT: [[TMP50:%.*]] = add nsw <8 x i64> [[TMP44]], [[TMP54]]
196196
; CHECK-NEXT: [[TMP51:%.*]] = getelementptr inbounds [10 x i32], <8 x ptr> [[TMP45]], <8 x i64> [[TMP50]], i64 0
197197
; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 8), <8 x ptr> [[TMP51]], i32 8, <8 x i1> [[TMP48]])
198198
; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 7), <8 x ptr> [[TMP47]], i32 16, <8 x i1> [[BROADCAST_SPLAT73]])

0 commit comments

Comments
 (0)