Skip to content
Open
2 changes: 2 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ class TargetTransformInfo {
/// Get the kind of extension that an instruction represents.
LLVM_ABI static PartialReductionExtendKind
getPartialReductionExtendKind(Instruction *I);
LLVM_ABI static PartialReductionExtendKind
getPartialReductionExtendKind(Instruction::CastOps CastOpc);

/// Construct a TTI object using a type implementing the \c Concept
/// API below.
Expand Down
19 changes: 15 additions & 4 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1001,13 +1001,24 @@ InstructionCost TargetTransformInfo::getShuffleCost(

TargetTransformInfo::PartialReductionExtendKind
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
if (isa<SExtInst>(I))
return PR_SignExtend;
if (isa<ZExtInst>(I))
return PR_ZeroExtend;
if (auto *Cast = dyn_cast<CastInst>(I))
return getPartialReductionExtendKind(Cast->getOpcode());
return PR_None;
}

TargetTransformInfo::PartialReductionExtendKind
TargetTransformInfo::getPartialReductionExtendKind(
Instruction::CastOps CastOpc) {
switch (CastOpc) {
case Instruction::CastOps::ZExt:
return PR_ZeroExtend;
case Instruction::CastOps::SExt:
return PR_SignExtend;
default:
return PR_None;
}
}

TTI::CastContextHint
TargetTransformInfo::getCastContextHint(const Instruction *I) {
if (!I)
Expand Down
8 changes: 6 additions & 2 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -2710,7 +2710,8 @@ class LLVM_ABI_FOR_TEST VPReductionRecipe : public VPRecipeWithIRFlags {

static inline bool classof(const VPRecipeBase *R) {
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
R->getVPDefID() == VPRecipeBase::VPPartialReductionSC;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this was missed before and only now is tested?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's right.

}

static inline bool classof(const VPUser *U) {
Expand Down Expand Up @@ -2772,7 +2773,10 @@ class VPPartialReductionRecipe : public VPReductionRecipe {
Opcode(Opcode), VFScaleFactor(ScaleFactor) {
[[maybe_unused]] auto *AccumulatorRecipe =
getChainOp()->getDefiningRecipe();
assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
// When cloning as part of a VPExpressionRecipe the chain op could have
// replaced by a temporary VPValue, so it doesn't have a defining recipe.
assert((!AccumulatorRecipe ||
isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
"Unexpected operand order for partial reduction recipe");
}
Expand Down
49 changes: 35 additions & 14 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
return cast<VPWidenIntrinsicRecipe>(this)->mayHaveSideEffects();
case VPBlendSC:
case VPReductionEVLSC:
case VPPartialReductionSC:
case VPReductionSC:
case VPScalarIVStepsSC:
case VPVectorPointerSC:
Expand Down Expand Up @@ -300,11 +301,11 @@ InstructionCost
VPPartialReductionRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
std::optional<unsigned> Opcode;
VPValue *Op = getOperand(0);
VPValue *Op = getOperand(1);
VPRecipeBase *OpR = Op->getDefiningRecipe();

// If the partial reduction is predicated, a select will be operand 0
if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(Op), m_VPValue()))) {
// If the partial reduction is predicated, a select will be operand 1
if (match(Op, m_Select(m_VPValue(), m_VPValue(Op), m_VPValue()))) {
OpR = Op->getDefiningRecipe();
}

Expand Down Expand Up @@ -2841,11 +2842,18 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
cast<VPReductionRecipe>(ExpressionRecipes.back())->getRecurrenceKind());
switch (ExpressionType) {
case ExpressionTypes::ExtendedReduction: {
unsigned Opcode = RecurrenceDescriptor::getOpcode(
cast<VPReductionRecipe>(ExpressionRecipes[1])->getRecurrenceKind());
auto *ExtR = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) {
return Ctx.TTI.getPartialReductionCost(
Opcode, Ctx.Types.inferScalarType(getOperand(0)), nullptr, RedTy, VF,
TargetTransformInfo::getPartialReductionExtendKind(ExtR->getOpcode()),
TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind);
}
return Ctx.TTI.getExtendedReductionCost(
Opcode,
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
Instruction::ZExt,
RedTy, SrcVecTy, std::nullopt, Ctx.CostKind);
Opcode, ExtR->getOpcode() == Instruction::ZExt, RedTy, SrcVecTy,
std::nullopt, Ctx.CostKind);
}
case ExpressionTypes::MulAccReduction:
return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy,
Expand All @@ -2856,6 +2864,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
Opcode = Instruction::Sub;
LLVM_FALLTHROUGH;
case ExpressionTypes::ExtMulAccReduction: {
if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) {
auto *Ext0R = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
auto *Ext1R = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
Comment on lines +2868 to +2869
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work as expected for all test on current main? I think at least in some cases one of the operands may be a constant live-in.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the matching function in VPlanTransforms explicitly checks for two extends, so the constant variant doesn't get bundled currently. I'm happy to get that working separately if necessary.

auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
return Ctx.TTI.getPartialReductionCost(
Opcode, Ctx.Types.inferScalarType(getOperand(0)),
Ctx.Types.inferScalarType(getOperand(1)), RedTy, VF,
TargetTransformInfo::getPartialReductionExtendKind(
Ext0R->getOpcode()),
TargetTransformInfo::getPartialReductionExtendKind(
Ext1R->getOpcode()),
Mul->getOpcode(), Ctx.CostKind);
}
return Ctx.TTI.getMulAccReductionCost(
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
Instruction::ZExt,
Expand Down Expand Up @@ -2888,12 +2909,13 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
O << " = ";
auto *Red = cast<VPReductionRecipe>(ExpressionRecipes.back());
unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);

switch (ExpressionType) {
case ExpressionTypes::ExtendedReduction: {
getOperand(1)->printAsOperand(O, SlotTracker);
O << " +";
O << " reduce." << Instruction::getOpcodeName(Opcode) << " (";
O << " + " << (IsPartialReduction ? "partial." : "") << "reduce.";
O << Instruction::getOpcodeName(Opcode) << " (";
getOperand(0)->printAsOperand(O, SlotTracker);
Red->printFlags(O);

Expand All @@ -2909,8 +2931,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
}
case ExpressionTypes::ExtNegatedMulAccReduction: {
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
O << " + reduce."
<< Instruction::getOpcodeName(
O << " + " << (IsPartialReduction ? "partial." : "") << "reduce.";
O << Instruction::getOpcodeName(
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
<< " (sub (0, mul";
auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
Expand All @@ -2934,9 +2956,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
case ExpressionTypes::MulAccReduction:
case ExpressionTypes::ExtMulAccReduction: {
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
O << " + ";
O << "reduce."
<< Instruction::getOpcodeName(
O << " + " << (IsPartialReduction ? "partial." : "") << "reduce.";
O << Instruction::getOpcodeName(
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
<< " (";
O << "mul";
Expand Down
80 changes: 56 additions & 24 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3465,18 +3465,31 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
VPValue *VecOp = Red->getVecOp();

// Clamp the range if using extended-reduction is profitable.
auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
Type *SrcTy) -> bool {
auto IsExtendedRedValidAndClampRange =
[&](unsigned Opcode, Instruction::CastOps ExtOpc, Type *SrcTy) -> bool {
return LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
Opcode, isZExt, RedTy, SrcVecTy, Red->getFastMathFlags(),
CostKind);

InstructionCost ExtRedCost;
InstructionCost ExtCost =
cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
InstructionCost RedCost = Red->computeCost(VF, Ctx);

if (isa<VPPartialReductionRecipe>(Red)) {
TargetTransformInfo::PartialReductionExtendKind ExtKind =
TargetTransformInfo::getPartialReductionExtendKind(ExtOpc);
// FIXME: Move partial reduction creation, costing and clamping
// here from LoopVectorize.cpp.
ExtRedCost = Ctx.TTI.getPartialReductionCost(
Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
llvm::TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind);
} else {
ExtRedCost = Ctx.TTI.getExtendedReductionCost(
Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
Red->getFastMathFlags(), CostKind);
}
return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
},
Range);
Expand All @@ -3487,8 +3500,7 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
IsExtendedRedValidAndClampRange(
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()),
cast<VPWidenCastRecipe>(VecOp)->getOpcode() ==
Instruction::CastOps::ZExt,
cast<VPWidenCastRecipe>(VecOp)->getOpcode(),
Ctx.Types.inferScalarType(A)))
return new VPExpressionRecipe(cast<VPWidenCastRecipe>(VecOp), Red);

Expand All @@ -3506,6 +3518,8 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
static VPExpressionRecipe *
tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
VPCostContext &Ctx, VFRange &Range) {
bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);

unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
return nullptr;
Expand All @@ -3514,16 +3528,41 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,

// Clamp the range if using multiply-accumulate-reduction is profitable.
auto IsMulAccValidAndClampRange =
[&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
[&](VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1,
VPWidenCastRecipe *OuterExt) -> bool {
return LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
Type *SrcTy =
Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
isZExt, Opcode, RedTy, SrcVecTy, CostKind);
InstructionCost MulAccCost;

if (IsPartialReduction) {
Type *SrcTy2 =
Ext1 ? Ctx.Types.inferScalarType(Ext1->getOperand(0)) : nullptr;
// FIXME: Move partial reduction creation, costing and clamping
// here from LoopVectorize.cpp.
MulAccCost = Ctx.TTI.getPartialReductionCost(
Opcode, SrcTy, SrcTy2, RedTy, VF,
Ext0 ? TargetTransformInfo::getPartialReductionExtendKind(
Ext0->getOpcode())
: TargetTransformInfo::PR_None,
Ext1 ? TargetTransformInfo::getPartialReductionExtendKind(
Ext1->getOpcode())
: TargetTransformInfo::PR_None,
Mul->getOpcode(), CostKind);
} else {
// Only partial reductions support mixed extends at the moment.
if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode())
return false;

bool IsZExt =
!Ext0 || Ext0->getOpcode() == Instruction::CastOps::ZExt;
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
MulAccCost = Ctx.TTI.getMulAccReductionCost(IsZExt, Opcode, RedTy,
SrcVecTy, CostKind);
}

InstructionCost MulCost = Mul->computeCost(VF, Ctx);
InstructionCost RedCost = Red->computeCost(VF, Ctx);
InstructionCost ExtCost = 0;
Expand Down Expand Up @@ -3558,23 +3597,18 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());

// Match reduce.add(mul(ext, ext)).
if (RecipeA && RecipeB &&
(RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) &&
match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
// Match reduce.add/sub(mul(ext, ext)).
if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
Instruction::CastOps::ZExt,
Mul, RecipeA, RecipeB, nullptr)) {
IsMulAccValidAndClampRange(Mul, RecipeA, RecipeB, nullptr)) {
if (Sub)
return new VPExpressionRecipe(RecipeA, RecipeB, Mul,
cast<VPWidenRecipe>(Sub), Red);
return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
}
// Match reduce.add(mul).
// TODO: Add an expression type for this variant with a negated mul
if (!Sub &&
IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
if (!Sub && IsMulAccValidAndClampRange(Mul, nullptr, nullptr, nullptr))
return new VPExpressionRecipe(Mul, Red);
}
// TODO: Add an expression type for negated versions of other expression
Expand All @@ -3594,9 +3628,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
Ext0->getOpcode() == Ext1->getOpcode() &&
IsMulAccValidAndClampRange(Ext0->getOpcode() ==
Instruction::CastOps::ZExt,
Mul, Ext0, Ext1, Ext)) {
IsMulAccValidAndClampRange(Mul, Ext0, Ext1, Ext)) {
auto *NewExt0 = new VPWidenCastRecipe(
Ext0->getOpcode(), Ext0->getOperand(0), Ext->getResultType(), *Ext0,
*Ext0, Ext0->getDebugLoc());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ define i64 @test_two_ivs(ptr %a, ptr %b, i64 %start) #0 {
; CHECK-NEXT: Cost of 0 for VF 16: induction instruction %i.iv = phi i64 [ 0, %entry ], [ %i.iv.next, %for.body ]
; CHECK-NEXT: Cost of 0 for VF 16: induction instruction %j.iv = phi i64 [ %start, %entry ], [ %j.iv.next, %for.body ]
; CHECK-NEXT: Cost of 0 for VF 16: EMIT vp<{{.+}}> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
; CHECK: Cost for VF 16: 41
; CHECK: Cost for VF 16: 3
; CHECK: LV: Selecting VF: 16
entry:
br label %for.body
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
target triple = "aarch64-none-unknown-elf"

define i32 @dotp(ptr %a, ptr %b) #0 {
; CHECK-REGS-VP-NOT: LV(REG): Not considering vector loop of width vscale x 16 because it uses too many registers
; CHECK-REGS-VP: LV: Selecting VF: vscale x 8.
; CHECK-REGS-VP: LV: Selecting VF: vscale x 16.
;
; CHECK-NOREGS-VP: LV(REG): Not considering vector loop of width vscale x 8 because it uses too many registers
; CHECK-NOREGS-VP: LV(REG): Not considering vector loop of width vscale x 16 because it uses too many registers
Expand Down
Loading