Skip to content

Conversation

fhahn
Copy link
Contributor

@fhahn fhahn commented Sep 9, 2025

During CSE, we don't have to drop all poison-generating flags on mis-match, we can keep the ones common on both recipes.

During CSE, we don't have to drop poison-generating flags, if both the
re-used recipe and the to-be-replaced recipe have the same flags.
@llvmbot
Copy link
Member

llvmbot commented Sep 9, 2025

@llvm/pr-subscribers-vectorizers
@llvm/pr-subscribers-backend-powerpc

@llvm/pr-subscribers-llvm-transforms

Author: Florian Hahn (fhahn)

Changes

During CSE, we don't have to drop all poison-generating flags on mis-match, we can keep the ones common on both recipes.


Full diff: https://github.com/llvm/llvm-project/pull/157664.diff

6 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+4)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+36)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+2-2)
  • (modified) llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll (+1-1)
  • (modified) llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll (+4-4)
  • (modified) llvm/test/Transforms/LoopVectorize/flags.ll (+1-1)
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index b93bdf244237e..53291a931530f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -721,6 +721,10 @@ class VPIRFlags {
     AllFlags = Other.AllFlags;
   }
 
+  /// Only keep flags also present in \p Other. \p Other must have the same
+  /// OpType as the current object.
+  void intersectFlags(const VPIRFlags &Other);
+
   /// Drop all poison-generating flags.
   void dropPoisonGeneratingFlags() {
     // NOTE: This needs to be kept in-sync with
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 46162a9276469..9f1311fbd0687 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -392,6 +392,42 @@ void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 }
 #endif
 
+void VPIRFlags::intersectFlags(const VPIRFlags &Other) {
+  assert(OpType == Other.OpType && "OpType must match");
+  switch (OpType) {
+  case OperationType::OverflowingBinOp:
+    WrapFlags.HasNUW &= Other.WrapFlags.HasNUW;
+    WrapFlags.HasNSW &= Other.WrapFlags.HasNSW;
+    break;
+  case OperationType::Trunc:
+    TruncFlags.HasNUW &= Other.TruncFlags.HasNUW;
+    TruncFlags.HasNSW &= Other.TruncFlags.HasNSW;
+    break;
+  case OperationType::DisjointOp:
+    DisjointFlags.IsDisjoint &= Other.DisjointFlags.IsDisjoint;
+    break;
+  case OperationType::PossiblyExactOp:
+    ExactFlags.IsExact = Other.ExactFlags.IsExact;
+    break;
+  case OperationType::GEPOp:
+    GEPFlags &= Other.GEPFlags;
+    break;
+  case OperationType::FPMathOp:
+    FMFs.NoNaNs &= Other.FMFs.NoNaNs;
+    FMFs.NoInfs &= Other.FMFs.NoInfs;
+    break;
+  case OperationType::NonNegOp:
+    NonNegFlags.NonNeg &= Other.NonNegFlags.NonNeg;
+    break;
+  case OperationType::Cmp:
+    assert(CmpPredicate == Other.CmpPredicate && "Cannot drop CmpPredicate");
+    break;
+  case OperationType::Other:
+    assert(AllFlags == Other.AllFlags && "Cannot drop other flags");
+    break;
+  }
+}
+
 FastMathFlags VPIRFlags::getFastMathFlags() const {
   assert(OpType == OperationType::FPMathOp &&
          "recipe doesn't have fast math flags");
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 10b2f5df2e23e..d86b53dd894fb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2042,9 +2042,9 @@ void VPlanTransforms::cse(VPlan &Plan) {
         // V must dominate Def for a valid replacement.
         if (!VPDT.dominates(V->getParent(), VPBB))
           continue;
-        // Drop poison-generating flags when reusing a value.
+        // Only keep flags present on both V and Def.
         if (auto *RFlags = dyn_cast<VPRecipeWithIRFlags>(V))
-          RFlags->dropPoisonGeneratingFlags();
+          RFlags->intersectFlags(*cast<VPRecipeWithIRFlags>(Def));
         Def->replaceAllUsesWith(V);
         continue;
       }
diff --git a/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll b/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
index 36c3a2a612d82..db1f2c71e0f77 100644
--- a/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
+++ b/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
@@ -16,7 +16,7 @@ define dso_local void @test(ptr %Arr, i32 signext %Len) {
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP1:%.*]] = sext i32 [[INDEX]] to i64
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, ptr [[ARR:%.*]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[ARR:%.*]], i64 [[TMP1]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP2]], align 4
 ; CHECK-NEXT:    [[TMP4:%.*]] = call <4 x i32> @llvm.bswap.v4i32(<4 x i32> [[WIDE_LOAD]])
 ; CHECK-NEXT:    store <4 x i32> [[TMP4]], ptr [[TMP2]], align 4
diff --git a/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll b/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
index df54411f7e710..c2dfce0aa70b8 100644
--- a/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
@@ -142,8 +142,8 @@ define void @_Z3fn1v() #0 {
 ; CHECK-NEXT:    [[TMP32:%.*]] = add nsw <16 x i64> [[TMP30]], [[VEC_IND37]]
 ; CHECK-NEXT:    [[TMP33:%.*]] = getelementptr inbounds [10 x i32], <16 x ptr> [[TMP31]], <16 x i64> [[TMP32]], i64 0
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 8), <16 x ptr> [[TMP33]], i32 16, <16 x i1> [[TMP34]])
-; CHECK-NEXT:    [[TMP49:%.*]] = or <16 x i64> [[VEC_IND37]], splat (i64 1)
-; CHECK-NEXT:    [[TMP36:%.*]] = add <16 x i64> [[TMP30]], [[TMP49]]
+; CHECK-NEXT:    [[TMP49:%.*]] = or disjoint <16 x i64> [[VEC_IND37]], splat (i64 1)
+; CHECK-NEXT:    [[TMP36:%.*]] = add nsw <16 x i64> [[TMP30]], [[TMP49]]
 ; CHECK-NEXT:    [[TMP37:%.*]] = getelementptr inbounds [10 x i32], <16 x ptr> [[TMP31]], <16 x i64> [[TMP36]], i64 0
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 8), <16 x ptr> [[TMP37]], i32 8, <16 x i1> [[TMP34]])
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 7), <16 x ptr> [[TMP33]], i32 16, <16 x i1> [[BROADCAST_SPLAT]])
@@ -191,8 +191,8 @@ define void @_Z3fn1v() #0 {
 ; CHECK-NEXT:    [[TMP46:%.*]] = add nsw <8 x i64> [[TMP44]], [[VEC_IND70]]
 ; CHECK-NEXT:    [[TMP47:%.*]] = getelementptr inbounds [10 x i32], <8 x ptr> [[TMP45]], <8 x i64> [[TMP46]], i64 0
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 8), <8 x ptr> [[TMP47]], i32 16, <8 x i1> [[TMP48]])
-; CHECK-NEXT:    [[TMP54:%.*]] = or <8 x i64> [[VEC_IND70]], splat (i64 1)
-; CHECK-NEXT:    [[TMP50:%.*]] = add <8 x i64> [[TMP44]], [[TMP54]]
+; CHECK-NEXT:    [[TMP54:%.*]] = or disjoint <8 x i64> [[VEC_IND70]], splat (i64 1)
+; CHECK-NEXT:    [[TMP50:%.*]] = add nsw <8 x i64> [[TMP44]], [[TMP54]]
 ; CHECK-NEXT:    [[TMP51:%.*]] = getelementptr inbounds [10 x i32], <8 x ptr> [[TMP45]], <8 x i64> [[TMP50]], i64 0
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 8), <8 x ptr> [[TMP51]], i32 8, <8 x i1> [[TMP48]])
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 7), <8 x ptr> [[TMP47]], i32 16, <8 x i1> [[BROADCAST_SPLAT73]])
diff --git a/llvm/test/Transforms/LoopVectorize/flags.ll b/llvm/test/Transforms/LoopVectorize/flags.ll
index cef8ea656afaa..cbdcd50476b98 100644
--- a/llvm/test/Transforms/LoopVectorize/flags.ll
+++ b/llvm/test/Transforms/LoopVectorize/flags.ll
@@ -175,7 +175,7 @@ define void @gep_with_shared_nusw_and_others(i64 %n, ptr %A) {
 ; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
 ; CHECK:       [[VECTOR_BODY]]:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr float, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr nusw float, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x float>, ptr [[TMP1]], align 4
 ; CHECK-NEXT:    store <4 x float> [[WIDE_LOAD]], ptr [[TMP1]], align 4
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4

Copy link
Contributor

@artagnon artagnon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM modulo one error, thanks!

DisjointFlags.IsDisjoint &= Other.DisjointFlags.IsDisjoint;
break;
case OperationType::PossiblyExactOp:
ExactFlags.IsExact = Other.ExactFlags.IsExact;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ExactFlags.IsExact = Other.ExactFlags.IsExact;
ExactFlags.IsExact &= Other.ExactFlags.IsExact;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be fixed, thanks. Also added a few missing tests

@fhahn fhahn enabled auto-merge (squash) September 10, 2025 09:49
@fhahn fhahn merged commit c3e76b2 into llvm:main Sep 10, 2025
9 checks passed
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Sep 10, 2025
During CSE, we don't have to drop all poison-generating flags on
mis-match, we can keep the ones common on both recipes.

PR: llvm/llvm-project#157664
@fhahn fhahn deleted the vplan-cse-retain-matching-flags branch September 10, 2025 12:12
@jyknight
Copy link
Member

This causes an assertion failure on the following (llvm-reduce'd) test-case, when run as opt -passes='loop-vectorize<no-interleave-forced-only;no-vectorize-forced-only;>'.

target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-linux-gnu"

define ptr @test(ptr %0) {
  br label %loop

loop:
  %2 = phi i64 [ 0, %1 ], [ %12, %loop ]
  %3 = zext i1 false to i64
  %4 = load i8, ptr %0, align 1
  %5 = and i8 %4, 1
  %6 = trunc i8 %4 to i1
  %7 = select i1 %6, float 1.000000e+00, float 0.000000e+00
  %8 = trunc i8 %5 to i1
  %9 = select i1 %8, float 0.000000e+00, float %7
  %10 = bitcast float %9 to i32
  %11 = trunc i32 %10 to i8
  store i8 %11, ptr null, align 1
  %12 = add i64 %2, 1
  %exitcond.not = icmp eq i64 %2, 1
  br i1 %exitcond.not, label %exit, label %loop

exit:
  ret ptr null
}

With error:

opt: llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp:396: void llvm::VPIRFlags::intersectFlags(const VPIRFlags &): Assertion `OpType == Other.OpType && "OpType must match"' failed.

(where OpType is "Other", and Other.OpType is "Trunc").

@fhahn
Copy link
Contributor Author

fhahn commented Sep 24, 2025

thanks for the heads up, taking a look

@fhahn
Copy link
Contributor Author

fhahn commented Sep 25, 2025

@jyknight should be fixed as part of the fix for #160396

@jyknight
Copy link
Member

@jyknight should be fixed as part of the fix for #160396

Thanks! Verified the fix on the original code.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants