Skip to content

Commit 13aeeb7

Browse files
committed
[LV] Support multiplies by constants when forming scaled reductions.
We can create partial reductions for multiplies with constants, if the constant is small enough to be extended from source to destination type w/o changing the value. This only handles constant on the right side of a multiply, relying on other passes to canonicalize the input. Alive2 Proofs: https://alive2.llvm.org/ce/z/iWRMr6
1 parent cbaf3c6 commit 13aeeb7

File tree

5 files changed

+47
-17
lines changed

5 files changed

+47
-17
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7937,6 +7937,13 @@ bool VPRecipeBuilder::getScaledReductions(
79377937
auto CollectExtInfo = [this, &Exts, &ExtOpTypes,
79387938
&ExtKinds](SmallVectorImpl<Value *> &Ops) -> bool {
79397939
for (const auto &[I, OpI] : enumerate(Ops)) {
7940+
auto *CI = dyn_cast<ConstantInt>(OpI);
7941+
if (I > 0 && CI &&
7942+
canConstantBeExtended(CI, ExtOpTypes[0], ExtKinds[0])) {
7943+
ExtOpTypes[I] = ExtOpTypes[0];
7944+
ExtKinds[I] = ExtKinds[0];
7945+
continue;
7946+
}
79407947
Value *ExtOp;
79417948
if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
79427949
return false;

llvm/lib/Transforms/Vectorize/VPlan.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1741,6 +1741,16 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) {
17411741
}
17421742
#endif
17431743

1744+
bool llvm::canConstantBeExtended(const ConstantInt *CI, Type *NarrowType,
1745+
TTI::PartialReductionExtendKind ExtKind) {
1746+
APInt TruncatedVal = CI->getValue().trunc(NarrowType->getScalarSizeInBits());
1747+
unsigned WideSize = CI->getType()->getScalarSizeInBits();
1748+
APInt ExtendedVal = ExtKind == TTI::PR_SignExtend
1749+
? TruncatedVal.sext(WideSize)
1750+
: TruncatedVal.zext(WideSize);
1751+
return ExtendedVal == CI->getValue();
1752+
}
1753+
17441754
TargetTransformInfo::OperandValueInfo
17451755
VPCostContext::getOperandInfo(VPValue *V) const {
17461756
if (!V->isLiveIn())

llvm/lib/Transforms/Vectorize/VPlanHelpers.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,10 @@ class VPlanPrinter {
468468
};
469469
#endif
470470

471+
/// Check if a constant \p CI can be safely treated as having been extended
472+
/// from a narrower type with the given extension kind.
473+
bool canConstantBeExtended(const ConstantInt *CI, Type *NarrowType,
474+
TTI::PartialReductionExtendKind ExtKind);
471475
} // end namespace llvm
472476

473477
#endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_H

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,15 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
340340
: Widen->getOperand(1));
341341
ExtAType = GetExtendKind(ExtAR);
342342
ExtBType = GetExtendKind(ExtBR);
343+
344+
if (!ExtBR && Widen->getOperand(1)->isLiveIn()) {
345+
auto *CI =
346+
dyn_cast<ConstantInt>(Widen->getOperand(1)->getLiveInIRValue());
347+
if (CI && canConstantBeExtended(CI, InputTypeA, ExtAType)) {
348+
InputTypeB = InputTypeA;
349+
ExtBType = ExtAType;
350+
}
351+
}
343352
};
344353

345354
if (isa<VPWidenCastRecipe>(OpR)) {

llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-constant-ops.ll

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,22 @@ define i32 @red_zext_mul_by_63(ptr %start, ptr %end) {
2020
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
2121
; CHECK: [[VECTOR_BODY]]:
2222
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
23-
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
23+
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
2424
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
2525
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
2626
; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
2727
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 63)
28-
; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
28+
; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
2929
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
30-
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
31-
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
30+
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
31+
; CHECK-NEXT: br i1 [[TMP5]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
3232
; CHECK: [[MIDDLE_BLOCK]]:
33-
; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
33+
; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
3434
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
3535
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
3636
; CHECK: [[SCALAR_PH]]:
3737
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi ptr [ [[TMP2]], %[[MIDDLE_BLOCK]] ], [ [[START]], %[[ENTRY]] ]
38-
; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
38+
; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP6]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
3939
; CHECK-NEXT: br label %[[LOOP:.*]]
4040
; CHECK: [[LOOP]]:
4141
; CHECK-NEXT: [[PTR_IV:%.*]] = phi ptr [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[GEP_IV_NEXT:%.*]], %[[LOOP]] ]
@@ -48,7 +48,7 @@ define i32 @red_zext_mul_by_63(ptr %start, ptr %end) {
4848
; CHECK-NEXT: [[EC:%.*]] = icmp eq ptr [[PTR_IV]], [[END]]
4949
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP3:![0-9]+]]
5050
; CHECK: [[EXIT]]:
51-
; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ]
51+
; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP6]], %[[MIDDLE_BLOCK]] ]
5252
; CHECK-NEXT: ret i32 [[RED_NEXT_LCSSA]]
5353
;
5454
entry:
@@ -86,17 +86,17 @@ define i32 @red_zext_mul_by_255(ptr %start, ptr %end) {
8686
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
8787
; CHECK: [[VECTOR_BODY]]:
8888
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
89-
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
89+
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
9090
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
9191
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
9292
; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
9393
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 255)
94-
; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
94+
; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
9595
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
9696
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
9797
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
9898
; CHECK: [[MIDDLE_BLOCK]]:
99-
; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
99+
; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
100100
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
101101
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
102102
; CHECK: [[SCALAR_PH]]:
@@ -218,22 +218,22 @@ define i32 @red_sext_mul_by_63(ptr %start, ptr %end) {
218218
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
219219
; CHECK: [[VECTOR_BODY]]:
220220
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
221-
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
221+
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
222222
; CHECK-NEXT: [[NEXT_GEP:%.*]] = getelementptr i8, ptr [[START]], i64 [[INDEX]]
223223
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[NEXT_GEP]], align 1
224224
; CHECK-NEXT: [[TMP3:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
225225
; CHECK-NEXT: [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], splat (i32 63)
226-
; CHECK-NEXT: [[TMP5]] = add <16 x i32> [[VEC_PHI]], [[TMP4]]
226+
; CHECK-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
227227
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
228-
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
229-
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
228+
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
229+
; CHECK-NEXT: br i1 [[TMP5]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
230230
; CHECK: [[MIDDLE_BLOCK]]:
231-
; CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP5]])
231+
; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE]])
232232
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
233233
; CHECK-NEXT: br i1 [[CMP_N]], label %[[EXIT:.*]], label %[[SCALAR_PH]]
234234
; CHECK: [[SCALAR_PH]]:
235235
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi ptr [ [[TMP2]], %[[MIDDLE_BLOCK]] ], [ [[START]], %[[ENTRY]] ]
236-
; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
236+
; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP6]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
237237
; CHECK-NEXT: br label %[[LOOP:.*]]
238238
; CHECK: [[LOOP]]:
239239
; CHECK-NEXT: [[PTR_IV:%.*]] = phi ptr [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[GEP_IV_NEXT:%.*]], %[[LOOP]] ]
@@ -246,7 +246,7 @@ define i32 @red_sext_mul_by_63(ptr %start, ptr %end) {
246246
; CHECK-NEXT: [[EC:%.*]] = icmp eq ptr [[PTR_IV]], [[END]]
247247
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP9:![0-9]+]]
248248
; CHECK: [[EXIT]]:
249-
; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ]
249+
; CHECK-NEXT: [[RED_NEXT_LCSSA:%.*]] = phi i32 [ [[RED_NEXT]], %[[LOOP]] ], [ [[TMP6]], %[[MIDDLE_BLOCK]] ]
250250
; CHECK-NEXT: ret i32 [[RED_NEXT_LCSSA]]
251251
;
252252
entry:

0 commit comments

Comments
 (0)