Commit 6ac7f45

Add wasm test and assert on the partial reduction cost
1 parent d9b3a7b commit 6ac7f45
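This commit changes the cost check so that a partial reduction is accepted unconditionally: by the time tryToMatchAndCreateExtendedReduction (and the mul-accumulate variant) runs, the VF range has already been clamped for the partial reduction, so the code now only asserts that it is no more expensive than the separate extend plus reduction, while an ordinary extended reduction still has to beat that base cost. The following is a minimal, self-contained sketch of that decision, not the LLVM implementation; the names shouldFuse and Costs are hypothetical, and std::optional stands in for InstructionCost.

#include <cassert>
#include <optional>

// std::nullopt models an "invalid" cost, mirroring InstructionCost::isValid().
using Cost = std::optional<unsigned>;

struct Costs {
  unsigned Ext;  // cost of the widening cast on its own
  unsigned Red;  // cost of the plain (non-fused) reduction
  Cost Fused;    // cost of the fused extended or partial reduction
};

// Returns true when the fused recipe should be created.
bool shouldFuse(const Costs &C, bool IsPartialReduction) {
  unsigned BaseCost = C.Ext + C.Red;
  if (IsPartialReduction) {
    // The VF range was already clamped when the partial reduction was built,
    // so it is accepted unconditionally and only asserted to be profitable.
    assert(C.Fused && *C.Fused <= BaseCost &&
           "Cost of the partial reduction is more than the base cost");
    return true;
  }
  // A regular extended reduction must have a valid cost and beat the cost of
  // performing the extend and the reduction separately.
  return C.Fused.has_value() && *C.Fused < BaseCost;
}

int main() {
  assert(shouldFuse({/*Ext=*/1, /*Red=*/2, /*Fused=*/2}, false));   // 2 < 3
  assert(!shouldFuse({/*Ext=*/1, /*Red=*/2, /*Fused=*/4}, false));  // 4 >= 3
  assert(shouldFuse({/*Ext=*/1, /*Red=*/2, /*Fused=*/3}, true));    // asserted only
  return 0;
}

In the real recipe the costs come from VPWidenCastRecipe::computeCost, VPReductionRecipe::computeCost, and TTI::getExtendedReductionCost / getPartialReductionCost, as shown in the diff below.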

2 files changed: 184 additions, 17 deletions

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 44 additions & 17 deletions
@@ -3479,21 +3479,32 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
         TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;

         InstructionCost ExtRedCost;
+        InstructionCost ExtCost =
+            cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
+        InstructionCost RedCost = Red->computeCost(VF, Ctx);
+        InstructionCost BaseCost = ExtCost + RedCost;
+
         if (isa<VPPartialReductionRecipe>(Red)) {
           TargetTransformInfo::PartialReductionExtendKind ExtKind =
               TargetTransformInfo::getPartialReductionExtendKind(ExtOpc);
-          ExtRedCost = Ctx.TTI.getPartialReductionCost(
+          // The VF ranges have already been clamped for a partial reduction
+          // and its existence confirms that it's valid, so we don't need to
+          // perform any cost checks or more clamping. Just assert that the
+          // partial reduction is still profitable.
+          // FIXME: Move partial reduction creation, costing and clamping
+          // here.
+          InstructionCost Cost = Ctx.TTI.getPartialReductionCost(
               Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
               llvm::TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind);
+          assert(Cost <= BaseCost &&
+                 "Cost of the partial reduction is more than the base cost");
+          return true;
         } else {
           ExtRedCost = Ctx.TTI.getExtendedReductionCost(
               Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
               Red->getFastMathFlags(), CostKind);
         }
-        InstructionCost ExtCost =
-            cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
-        InstructionCost RedCost = Red->computeCost(VF, Ctx);
-        return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
+        return ExtRedCost.isValid() && ExtRedCost < BaseCost;
       },
       Range);
 };
@@ -3535,17 +3546,9 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
                           VPWidenCastRecipe *OuterExt) -> bool {
     return LoopVectorizationPlanner::getDecisionAndClampRange(
         [&](ElementCount VF) {
-          if (IsPartialReduction) {
-            // The VF ranges have already been clamped for a partial reduction
-            // and its existence confirms that it's valid, so we don't need to
-            // perform any cost checks or more clamping.
-            // FIXME: Move partial reduction creation, costing and clamping
-            // here.
-            return true;
-          }
-
           // Only partial reductions support mixed extends at the moment.
-          if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode())
+          if (!IsPartialReduction && Ext0 && Ext1 &&
+              Ext0->getOpcode() != Ext1->getOpcode())
             return false;

           bool IsZExt =
@@ -3566,8 +3569,32 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
           if (OuterExt)
             ExtCost += OuterExt->computeCost(VF, Ctx);

-          return MulAccCost.isValid() &&
-                 MulAccCost < ExtCost + MulCost + RedCost;
+          InstructionCost BaseCost = ExtCost + MulCost + RedCost;
+
+          if (IsPartialReduction) {
+            Type *SrcTy2 =
+                Ext1 ? Ctx.Types.inferScalarType(Ext1->getOperand(0)) : nullptr;
+            // The VF ranges have already been clamped for a partial reduction
+            // and its existence confirms that it's valid, so we don't need to
+            // perform any cost checks or more clamping. Just assert that the
+            // partial reduction is still profitable.
+            // FIXME: Move partial reduction creation, costing and clamping
+            // here.
+            InstructionCost Cost = Ctx.TTI.getPartialReductionCost(
+                Opcode, SrcTy, SrcTy2, RedTy, VF,
+                Ext0 ? TargetTransformInfo::getPartialReductionExtendKind(
+                           Ext0->getOpcode())
+                     : TargetTransformInfo::PR_None,
+                Ext1 ? TargetTransformInfo::getPartialReductionExtendKind(
+                           Ext1->getOpcode())
+                     : TargetTransformInfo::PR_None,
+                Mul->getOpcode(), CostKind);
+            assert(Cost <= BaseCost &&
+                   "Cost of the partial reduction is more than the base cost");
+            return true;
+          }
+
+          return MulAccCost.isValid() && MulAccCost < BaseCost;
         },
         Range);
   };
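The mul-accumulate hunks above follow the same policy, with two extra details: the base cost now also includes the multiply, and the partial-reduction cost query receives a separate source type and extend kind per multiplicand, defaulting to PR_None when an operand is not extended; mixed zext/sext operands remain legal only for partial reductions. Below is a rough sketch of that per-operand handling, with hypothetical enums standing in for Instruction::CastOps and TargetTransformInfo::PartialReductionExtendKind.

#include <cassert>
#include <optional>

// Hypothetical stand-ins for Instruction::CastOps and
// TargetTransformInfo::PartialReductionExtendKind.
enum class CastOp { ZExt, SExt };
enum class ExtendKind { None, ZeroExtend, SignExtend };

// Mirrors the Ext0/Ext1 handling in the diff: an absent extend maps to the
// "no extend" kind, otherwise the cast opcode selects the kind.
ExtendKind kindFor(std::optional<CastOp> Ext) {
  if (!Ext)
    return ExtendKind::None;
  return *Ext == CastOp::ZExt ? ExtendKind::ZeroExtend : ExtendKind::SignExtend;
}

// Only partial reductions support mixed extends at the moment: a plain
// mul-accumulate reduction with zext on one operand and sext on the other is
// rejected, while a partial reduction may query the cost with two different
// extend kinds.
bool mixedExtendsAllowed(bool IsPartialReduction, std::optional<CastOp> Ext0,
                         std::optional<CastOp> Ext1) {
  if (!IsPartialReduction && Ext0 && Ext1 && *Ext0 != *Ext1)
    return false;
  return true;
}

int main() {
  assert(kindFor(std::nullopt) == ExtendKind::None);
  assert(kindFor(CastOp::SExt) == ExtendKind::SignExtend);
  assert(!mixedExtendsAllowed(false, CastOp::ZExt, CastOp::SExt));
  assert(mixedExtendsAllowed(true, CastOp::ZExt, CastOp::SExt));
  return 0;
}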
Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt -mattr=+simd128 -passes=loop-vectorize %s -S | FileCheck %s
+; RUN: opt -mattr=+simd128 -passes=loop-vectorize -vectorizer-maximize-bandwidth %s -S | FileCheck %s --check-prefix=CHECK-MAX-BANDWIDTH
+
+target triple = "wasm32"
+
+define hidden i32 @accumulate_add_u8_u8(ptr noundef readonly %a, ptr noundef readonly %b, i32 noundef %N) {
+; CHECK-LABEL: define hidden i32 @accumulate_add_u8_u8(
+; CHECK-SAME: ptr noundef readonly [[A:%.*]], ptr noundef readonly [[B:%.*]], i32 noundef [[N:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: [[ENTRY:.*]]:
+; CHECK-NEXT: [[CMP8_NOT:%.*]] = icmp eq i32 [[N]], 0
+; CHECK-NEXT: br i1 [[CMP8_NOT]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY_PREHEADER:.*]]
+; CHECK: [[FOR_BODY_PREHEADER]]:
+; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 4
+; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
+; CHECK: [[VECTOR_PH]]:
+; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i32 [[N]], 4
+; CHECK-NEXT: [[N_VEC:%.*]] = sub i32 [[N]], [[N_MOD_VF]]
+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
+; CHECK: [[VECTOR_BODY]]:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP8:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i32 [[INDEX]]
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP2]], align 1
+; CHECK-NEXT: [[TMP3:%.*]] = zext <4 x i8> [[WIDE_LOAD]] to <4 x i32>
+; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i32 [[INDEX]]
+; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <4 x i8>, ptr [[TMP5]], align 1
+; CHECK-NEXT: [[TMP6:%.*]] = zext <4 x i8> [[WIDE_LOAD1]] to <4 x i32>
+; CHECK-NEXT: [[TMP7:%.*]] = add <4 x i32> [[VEC_PHI]], [[TMP3]]
+; CHECK-NEXT: [[TMP8]] = add <4 x i32> [[TMP7]], [[TMP6]]
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT: [[TMP9:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[TMP9]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK: [[MIDDLE_BLOCK]]:
+; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP8]])
+; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
+; CHECK-NEXT: br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP_LOOPEXIT:.*]], label %[[SCALAR_PH]]
+; CHECK: [[SCALAR_PH]]:
+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP10]], %[[MIDDLE_BLOCK]] ], [ 0, %[[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT: br label %[[FOR_BODY:.*]]
+; CHECK: [[FOR_COND_CLEANUP_LOOPEXIT]]:
+; CHECK-NEXT: [[ADD3_LCSSA:%.*]] = phi i32 [ [[ADD3:%.*]], %[[FOR_BODY]] ], [ [[TMP10]], %[[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: br label %[[FOR_COND_CLEANUP]]
+; CHECK: [[FOR_COND_CLEANUP]]:
+; CHECK-NEXT: [[RESULT_0_LCSSA:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[ADD3_LCSSA]], %[[FOR_COND_CLEANUP_LOOPEXIT]] ]
+; CHECK-NEXT: ret i32 [[RESULT_0_LCSSA]]
+; CHECK: [[FOR_BODY]]:
+; CHECK-NEXT: [[I_010:%.*]] = phi i32 [ [[INC:%.*]], %[[FOR_BODY]] ], [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ]
+; CHECK-NEXT: [[RESULT_09:%.*]] = phi i32 [ [[ADD3]], %[[FOR_BODY]] ], [ [[BC_MERGE_RDX]], %[[SCALAR_PH]] ]
+; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i32 [[I_010]]
+; CHECK-NEXT: [[TMP11:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
+; CHECK-NEXT: [[CONV:%.*]] = zext i8 [[TMP11]] to i32
+; CHECK-NEXT: [[ARRAYIDX1:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i32 [[I_010]]
+; CHECK-NEXT: [[TMP12:%.*]] = load i8, ptr [[ARRAYIDX1]], align 1
+; CHECK-NEXT: [[CONV2:%.*]] = zext i8 [[TMP12]] to i32
+; CHECK-NEXT: [[ADD:%.*]] = add i32 [[RESULT_09]], [[CONV]]
+; CHECK-NEXT: [[ADD3]] = add i32 [[ADD]], [[CONV2]]
+; CHECK-NEXT: [[INC]] = add nuw i32 [[I_010]], 1
+; CHECK-NEXT: [[EXITCOND_NOT:%.*]] = icmp eq i32 [[INC]], [[N]]
+; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label %[[FOR_COND_CLEANUP_LOOPEXIT]], label %[[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
+;
+; CHECK-MAX-BANDWIDTH-LABEL: define hidden i32 @accumulate_add_u8_u8(
+; CHECK-MAX-BANDWIDTH-SAME: ptr noundef readonly [[A:%.*]], ptr noundef readonly [[B:%.*]], i32 noundef [[N:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-MAX-BANDWIDTH-NEXT: [[ENTRY:.*]]:
+; CHECK-MAX-BANDWIDTH-NEXT: [[CMP8_NOT:%.*]] = icmp eq i32 [[N]], 0
+; CHECK-MAX-BANDWIDTH-NEXT: br i1 [[CMP8_NOT]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY_PREHEADER:.*]]
+; CHECK-MAX-BANDWIDTH: [[FOR_BODY_PREHEADER]]:
+; CHECK-MAX-BANDWIDTH-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 16
+; CHECK-MAX-BANDWIDTH-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
+; CHECK-MAX-BANDWIDTH: [[VECTOR_PH]]:
+; CHECK-MAX-BANDWIDTH-NEXT: [[N_MOD_VF:%.*]] = urem i32 [[N]], 16
+; CHECK-MAX-BANDWIDTH-NEXT: [[N_VEC:%.*]] = sub i32 [[N]], [[N_MOD_VF]]
+; CHECK-MAX-BANDWIDTH-NEXT: br label %[[VECTOR_BODY:.*]]
+; CHECK-MAX-BANDWIDTH: [[VECTOR_BODY]]:
+; CHECK-MAX-BANDWIDTH-NEXT: [[INDEX:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-MAX-BANDWIDTH-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE2:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-MAX-BANDWIDTH-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i32 [[INDEX]]
+; CHECK-MAX-BANDWIDTH-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
+; CHECK-MAX-BANDWIDTH-NEXT: [[TMP5:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i32 [[INDEX]]
+; CHECK-MAX-BANDWIDTH-NEXT: [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP5]], align 1
+; CHECK-MAX-BANDWIDTH-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-MAX-BANDWIDTH-NEXT: [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP3]])
+; CHECK-MAX-BANDWIDTH-NEXT: [[TMP6:%.*]] = zext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-MAX-BANDWIDTH-NEXT: [[PARTIAL_REDUCE2]] = call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP6]])
+; CHECK-MAX-BANDWIDTH-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 16
+; CHECK-MAX-BANDWIDTH-NEXT: [[TMP9:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-MAX-BANDWIDTH-NEXT: br i1 [[TMP9]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-MAX-BANDWIDTH: [[MIDDLE_BLOCK]]:
+; CHECK-MAX-BANDWIDTH-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE2]])
+; CHECK-MAX-BANDWIDTH-NEXT: [[CMP_N:%.*]] = icmp eq i32 [[N]], [[N_VEC]]
+; CHECK-MAX-BANDWIDTH-NEXT: br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP_LOOPEXIT:.*]], label %[[SCALAR_PH]]
+; CHECK-MAX-BANDWIDTH: [[SCALAR_PH]]:
+; CHECK-MAX-BANDWIDTH-NEXT: [[BC_RESUME_VAL:%.*]] = phi i32 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[FOR_BODY_PREHEADER]] ]
+; CHECK-MAX-BANDWIDTH-NEXT: [[BC_MERGE_RDX:%.*]] = phi i32 [ [[TMP7]], %[[MIDDLE_BLOCK]] ], [ 0, %[[FOR_BODY_PREHEADER]] ]
+; CHECK-MAX-BANDWIDTH-NEXT: br label %[[FOR_BODY:.*]]
+; CHECK-MAX-BANDWIDTH: [[FOR_COND_CLEANUP_LOOPEXIT]]:
+; CHECK-MAX-BANDWIDTH-NEXT: [[ADD3_LCSSA:%.*]] = phi i32 [ [[ADD3:%.*]], %[[FOR_BODY]] ], [ [[TMP7]], %[[MIDDLE_BLOCK]] ]
+; CHECK-MAX-BANDWIDTH-NEXT: br label %[[FOR_COND_CLEANUP]]
+; CHECK-MAX-BANDWIDTH: [[FOR_COND_CLEANUP]]:
+; CHECK-MAX-BANDWIDTH-NEXT: [[RESULT_0_LCSSA:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[ADD3_LCSSA]], %[[FOR_COND_CLEANUP_LOOPEXIT]] ]
+; CHECK-MAX-BANDWIDTH-NEXT: ret i32 [[RESULT_0_LCSSA]]
+; CHECK-MAX-BANDWIDTH: [[FOR_BODY]]:
+; CHECK-MAX-BANDWIDTH-NEXT: [[I_010:%.*]] = phi i32 [ [[INC:%.*]], %[[FOR_BODY]] ], [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ]
+; CHECK-MAX-BANDWIDTH-NEXT: [[RESULT_09:%.*]] = phi i32 [ [[ADD3]], %[[FOR_BODY]] ], [ [[BC_MERGE_RDX]], %[[SCALAR_PH]] ]
+; CHECK-MAX-BANDWIDTH-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i32 [[I_010]]
+; CHECK-MAX-BANDWIDTH-NEXT: [[TMP11:%.*]] = load i8, ptr [[ARRAYIDX]], align 1
+; CHECK-MAX-BANDWIDTH-NEXT: [[CONV:%.*]] = zext i8 [[TMP11]] to i32
+; CHECK-MAX-BANDWIDTH-NEXT: [[ARRAYIDX1:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i32 [[I_010]]
+; CHECK-MAX-BANDWIDTH-NEXT: [[TMP12:%.*]] = load i8, ptr [[ARRAYIDX1]], align 1
+; CHECK-MAX-BANDWIDTH-NEXT: [[CONV2:%.*]] = zext i8 [[TMP12]] to i32
+; CHECK-MAX-BANDWIDTH-NEXT: [[ADD:%.*]] = add i32 [[RESULT_09]], [[CONV]]
+; CHECK-MAX-BANDWIDTH-NEXT: [[ADD3]] = add i32 [[ADD]], [[CONV2]]
+; CHECK-MAX-BANDWIDTH-NEXT: [[INC]] = add nuw i32 [[I_010]], 1
+; CHECK-MAX-BANDWIDTH-NEXT: [[EXITCOND_NOT:%.*]] = icmp eq i32 [[INC]], [[N]]
+; CHECK-MAX-BANDWIDTH-NEXT: br i1 [[EXITCOND_NOT]], label %[[FOR_COND_CLEANUP_LOOPEXIT]], label %[[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
+;
+entry:
+  %cmp8.not = icmp eq i32 %N, 0
+  br i1 %cmp8.not, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %result.0.lcssa = phi i32 [ 0, %entry ], [ %add3, %for.body ]
+  ret i32 %result.0.lcssa
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+  %result.09 = phi i32 [ %add3, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds nuw i8, ptr %a, i32 %i.010
+  %0 = load i8, ptr %arrayidx, align 1
+  %conv = zext i8 %0 to i32
+  %arrayidx1 = getelementptr inbounds nuw i8, ptr %b, i32 %i.010
+  %1 = load i8, ptr %arrayidx1, align 1
+  %conv2 = zext i8 %1 to i32
+  %add = add i32 %result.09, %conv
+  %add3 = add i32 %add, %conv2
+  %inc = add nuw i32 %i.010, 1
+  %exitcond.not = icmp eq i32 %inc, %N
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
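In the -vectorizer-maximize-bandwidth run the check lines above expect llvm.vector.partial.reduce.add.v4i32.v16i32, which folds a <16 x i32> operand into the <4 x i32> accumulator. The intrinsic only pins down the overall sum, not which input lanes feed which result lane, so the sketch below shows just one valid interpretation (a strided grouping); the helper name partialReduceAdd is illustrative only.

#include <array>
#include <cassert>
#include <cstdint>

// One legal interpretation of
//   <4 x i32> llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> acc, <16 x i32> in)
// The intrinsic only guarantees that the sum of the result equals the sum of
// the accumulator plus the sum of the input; the grouping below (result lane
// i takes input lanes i, i+4, i+8, i+12) is just one valid choice.
std::array<int32_t, 4> partialReduceAdd(std::array<int32_t, 4> Acc,
                                        const std::array<int32_t, 16> &In) {
  for (unsigned Lane = 0; Lane < 16; ++Lane)
    Acc[Lane % 4] += In[Lane];
  return Acc;
}

int main() {
  std::array<int32_t, 16> In{};
  for (unsigned I = 0; I < 16; ++I)
    In[I] = 1; // sixteen u8 values, already zero-extended to i32
  auto Res = partialReduceAdd({0, 0, 0, 0}, In);
  // A final reduction over the 4 lanes yields the full sum.
  assert(Res[0] + Res[1] + Res[2] + Res[3] == 16);
  return 0;
}

The final i32 result then comes from the llvm.vector.reduce.add.v4i32 call in the middle block, exactly as in both check prefixes.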
