-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[VPlan] Keep common flags during CSE. #157664
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
During CSE, we don't have to drop poison-generating flags, if both the re-used recipe and the to-be-replaced recipe have the same flags.
@llvm/pr-subscribers-vectorizers @llvm/pr-subscribers-llvm-transforms Author: Florian Hahn (fhahn) ChangesDuring CSE, we don't have to drop all poison-generating flags on mis-match, we can keep the ones common on both recipes. Full diff: https://github.com/llvm/llvm-project/pull/157664.diff 6 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index b93bdf244237e..53291a931530f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -721,6 +721,10 @@ class VPIRFlags {
AllFlags = Other.AllFlags;
}
+ /// Only keep flags also present in \p Other. \p Other must have the same
+ /// OpType as the current object.
+ void intersectFlags(const VPIRFlags &Other);
+
/// Drop all poison-generating flags.
void dropPoisonGeneratingFlags() {
// NOTE: This needs to be kept in-sync with
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 46162a9276469..9f1311fbd0687 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -392,6 +392,42 @@ void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
}
#endif
+void VPIRFlags::intersectFlags(const VPIRFlags &Other) {
+ assert(OpType == Other.OpType && "OpType must match");
+ switch (OpType) {
+ case OperationType::OverflowingBinOp:
+ WrapFlags.HasNUW &= Other.WrapFlags.HasNUW;
+ WrapFlags.HasNSW &= Other.WrapFlags.HasNSW;
+ break;
+ case OperationType::Trunc:
+ TruncFlags.HasNUW &= Other.TruncFlags.HasNUW;
+ TruncFlags.HasNSW &= Other.TruncFlags.HasNSW;
+ break;
+ case OperationType::DisjointOp:
+ DisjointFlags.IsDisjoint &= Other.DisjointFlags.IsDisjoint;
+ break;
+ case OperationType::PossiblyExactOp:
+ ExactFlags.IsExact = Other.ExactFlags.IsExact;
+ break;
+ case OperationType::GEPOp:
+ GEPFlags &= Other.GEPFlags;
+ break;
+ case OperationType::FPMathOp:
+ FMFs.NoNaNs &= Other.FMFs.NoNaNs;
+ FMFs.NoInfs &= Other.FMFs.NoInfs;
+ break;
+ case OperationType::NonNegOp:
+ NonNegFlags.NonNeg &= Other.NonNegFlags.NonNeg;
+ break;
+ case OperationType::Cmp:
+ assert(CmpPredicate == Other.CmpPredicate && "Cannot drop CmpPredicate");
+ break;
+ case OperationType::Other:
+ assert(AllFlags == Other.AllFlags && "Cannot drop other flags");
+ break;
+ }
+}
+
FastMathFlags VPIRFlags::getFastMathFlags() const {
assert(OpType == OperationType::FPMathOp &&
"recipe doesn't have fast math flags");
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 10b2f5df2e23e..d86b53dd894fb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2042,9 +2042,9 @@ void VPlanTransforms::cse(VPlan &Plan) {
// V must dominate Def for a valid replacement.
if (!VPDT.dominates(V->getParent(), VPBB))
continue;
- // Drop poison-generating flags when reusing a value.
+ // Only keep flags present on both V and Def.
if (auto *RFlags = dyn_cast<VPRecipeWithIRFlags>(V))
- RFlags->dropPoisonGeneratingFlags();
+ RFlags->intersectFlags(*cast<VPRecipeWithIRFlags>(Def));
Def->replaceAllUsesWith(V);
continue;
}
diff --git a/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll b/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
index 36c3a2a612d82..db1f2c71e0f77 100644
--- a/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
+++ b/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
@@ -16,7 +16,7 @@ define dso_local void @test(ptr %Arr, i32 signext %Len) {
; CHECK: vector.body:
; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[INDEX]] to i64
-; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i32, ptr [[ARR:%.*]], i64 [[TMP1]]
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[ARR:%.*]], i64 [[TMP1]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP2]], align 4
; CHECK-NEXT: [[TMP4:%.*]] = call <4 x i32> @llvm.bswap.v4i32(<4 x i32> [[WIDE_LOAD]])
; CHECK-NEXT: store <4 x i32> [[TMP4]], ptr [[TMP2]], align 4
diff --git a/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll b/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
index df54411f7e710..c2dfce0aa70b8 100644
--- a/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
@@ -142,8 +142,8 @@ define void @_Z3fn1v() #0 {
; CHECK-NEXT: [[TMP32:%.*]] = add nsw <16 x i64> [[TMP30]], [[VEC_IND37]]
; CHECK-NEXT: [[TMP33:%.*]] = getelementptr inbounds [10 x i32], <16 x ptr> [[TMP31]], <16 x i64> [[TMP32]], i64 0
; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 8), <16 x ptr> [[TMP33]], i32 16, <16 x i1> [[TMP34]])
-; CHECK-NEXT: [[TMP49:%.*]] = or <16 x i64> [[VEC_IND37]], splat (i64 1)
-; CHECK-NEXT: [[TMP36:%.*]] = add <16 x i64> [[TMP30]], [[TMP49]]
+; CHECK-NEXT: [[TMP49:%.*]] = or disjoint <16 x i64> [[VEC_IND37]], splat (i64 1)
+; CHECK-NEXT: [[TMP36:%.*]] = add nsw <16 x i64> [[TMP30]], [[TMP49]]
; CHECK-NEXT: [[TMP37:%.*]] = getelementptr inbounds [10 x i32], <16 x ptr> [[TMP31]], <16 x i64> [[TMP36]], i64 0
; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 8), <16 x ptr> [[TMP37]], i32 8, <16 x i1> [[TMP34]])
; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 7), <16 x ptr> [[TMP33]], i32 16, <16 x i1> [[BROADCAST_SPLAT]])
@@ -191,8 +191,8 @@ define void @_Z3fn1v() #0 {
; CHECK-NEXT: [[TMP46:%.*]] = add nsw <8 x i64> [[TMP44]], [[VEC_IND70]]
; CHECK-NEXT: [[TMP47:%.*]] = getelementptr inbounds [10 x i32], <8 x ptr> [[TMP45]], <8 x i64> [[TMP46]], i64 0
; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 8), <8 x ptr> [[TMP47]], i32 16, <8 x i1> [[TMP48]])
-; CHECK-NEXT: [[TMP54:%.*]] = or <8 x i64> [[VEC_IND70]], splat (i64 1)
-; CHECK-NEXT: [[TMP50:%.*]] = add <8 x i64> [[TMP44]], [[TMP54]]
+; CHECK-NEXT: [[TMP54:%.*]] = or disjoint <8 x i64> [[VEC_IND70]], splat (i64 1)
+; CHECK-NEXT: [[TMP50:%.*]] = add nsw <8 x i64> [[TMP44]], [[TMP54]]
; CHECK-NEXT: [[TMP51:%.*]] = getelementptr inbounds [10 x i32], <8 x ptr> [[TMP45]], <8 x i64> [[TMP50]], i64 0
; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 8), <8 x ptr> [[TMP51]], i32 8, <8 x i1> [[TMP48]])
; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 7), <8 x ptr> [[TMP47]], i32 16, <8 x i1> [[BROADCAST_SPLAT73]])
diff --git a/llvm/test/Transforms/LoopVectorize/flags.ll b/llvm/test/Transforms/LoopVectorize/flags.ll
index cef8ea656afaa..cbdcd50476b98 100644
--- a/llvm/test/Transforms/LoopVectorize/flags.ll
+++ b/llvm/test/Transforms/LoopVectorize/flags.ll
@@ -175,7 +175,7 @@ define void @gep_with_shared_nusw_and_others(i64 %n, ptr %A) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[TMP1:%.*]] = getelementptr float, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr nusw float, ptr [[A]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x float>, ptr [[TMP1]], align 4
; CHECK-NEXT: store <4 x float> [[WIDE_LOAD]], ptr [[TMP1]], align 4
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM modulo one error, thanks!
DisjointFlags.IsDisjoint &= Other.DisjointFlags.IsDisjoint; | ||
break; | ||
case OperationType::PossiblyExactOp: | ||
ExactFlags.IsExact = Other.ExactFlags.IsExact; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ExactFlags.IsExact = Other.ExactFlags.IsExact; | |
ExactFlags.IsExact &= Other.ExactFlags.IsExact; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be fixed, thanks. Also added a few missing tests
During CSE, we don't have to drop all poison-generating flags on mis-match, we can keep the ones common on both recipes. PR: llvm/llvm-project#157664
This causes an assertion failure on the following (llvm-reduce'd) test-case, when run as
With error:
(where OpType is "Other", and Other.OpType is "Trunc"). |
thanks for the heads up, taking a look |
During CSE, we don't have to drop all poison-generating flags on mis-match, we can keep the ones common on both recipes.