-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[LV] Bundle partial reductions inside VPExpressionRecipe #147302
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7015be2
66ef470
04c749b
7c3ed6e
5f8c472
0c75668
893ab05
43b1b6a
281a307
f9d4b8b
fa6660d
527bbbf
d634251
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -168,6 +168,7 @@ bool VPRecipeBase::mayHaveSideEffects() const { | |
return cast<VPWidenIntrinsicRecipe>(this)->mayHaveSideEffects(); | ||
case VPBlendSC: | ||
case VPReductionEVLSC: | ||
case VPPartialReductionSC: | ||
case VPReductionSC: | ||
case VPScalarIVStepsSC: | ||
case VPVectorPointerSC: | ||
|
@@ -300,11 +301,11 @@ InstructionCost | |
VPPartialReductionRecipe::computeCost(ElementCount VF, | ||
VPCostContext &Ctx) const { | ||
std::optional<unsigned> Opcode; | ||
VPValue *Op = getOperand(0); | ||
VPValue *Op = getOperand(1); | ||
VPRecipeBase *OpR = Op->getDefiningRecipe(); | ||
|
||
// If the partial reduction is predicated, a select will be operand 0 | ||
if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(Op), m_VPValue()))) { | ||
// If the partial reduction is predicated, a select will be operand 1 | ||
if (match(Op, m_Select(m_VPValue(), m_VPValue(Op), m_VPValue()))) { | ||
OpR = Op->getDefiningRecipe(); | ||
} | ||
|
||
|
@@ -2841,11 +2842,18 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, | |
cast<VPReductionRecipe>(ExpressionRecipes.back())->getRecurrenceKind()); | ||
switch (ExpressionType) { | ||
case ExpressionTypes::ExtendedReduction: { | ||
unsigned Opcode = RecurrenceDescriptor::getOpcode( | ||
cast<VPReductionRecipe>(ExpressionRecipes[1])->getRecurrenceKind()); | ||
auto *ExtR = cast<VPWidenCastRecipe>(ExpressionRecipes[0]); | ||
if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) { | ||
return Ctx.TTI.getPartialReductionCost( | ||
Opcode, Ctx.Types.inferScalarType(getOperand(0)), nullptr, RedTy, VF, | ||
TargetTransformInfo::getPartialReductionExtendKind(ExtR->getOpcode()), | ||
TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind); | ||
} | ||
return Ctx.TTI.getExtendedReductionCost( | ||
Opcode, | ||
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() == | ||
Instruction::ZExt, | ||
RedTy, SrcVecTy, std::nullopt, Ctx.CostKind); | ||
Opcode, ExtR->getOpcode() == Instruction::ZExt, RedTy, SrcVecTy, | ||
std::nullopt, Ctx.CostKind); | ||
} | ||
case ExpressionTypes::MulAccReduction: | ||
return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy, | ||
|
@@ -2856,6 +2864,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, | |
Opcode = Instruction::Sub; | ||
LLVM_FALLTHROUGH; | ||
case ExpressionTypes::ExtMulAccReduction: { | ||
if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) { | ||
auto *Ext0R = cast<VPWidenCastRecipe>(ExpressionRecipes[0]); | ||
auto *Ext1R = cast<VPWidenCastRecipe>(ExpressionRecipes[1]); | ||
Comment on lines
+2868
to
+2869
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this work as expected for all test on current main? I think at least in some cases one of the operands may be a constant live-in. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah the matching function in |
||
auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]); | ||
return Ctx.TTI.getPartialReductionCost( | ||
Opcode, Ctx.Types.inferScalarType(getOperand(0)), | ||
Ctx.Types.inferScalarType(getOperand(1)), RedTy, VF, | ||
TargetTransformInfo::getPartialReductionExtendKind( | ||
Ext0R->getOpcode()), | ||
TargetTransformInfo::getPartialReductionExtendKind( | ||
Ext1R->getOpcode()), | ||
Mul->getOpcode(), Ctx.CostKind); | ||
} | ||
return Ctx.TTI.getMulAccReductionCost( | ||
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() == | ||
Instruction::ZExt, | ||
|
@@ -2888,12 +2909,13 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, | |
O << " = "; | ||
auto *Red = cast<VPReductionRecipe>(ExpressionRecipes.back()); | ||
unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()); | ||
bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red); | ||
|
||
switch (ExpressionType) { | ||
case ExpressionTypes::ExtendedReduction: { | ||
getOperand(1)->printAsOperand(O, SlotTracker); | ||
O << " +"; | ||
O << " reduce." << Instruction::getOpcodeName(Opcode) << " ("; | ||
O << " + " << (IsPartialReduction ? "partial." : "") << "reduce."; | ||
O << Instruction::getOpcodeName(Opcode) << " ("; | ||
getOperand(0)->printAsOperand(O, SlotTracker); | ||
Red->printFlags(O); | ||
|
||
|
@@ -2909,8 +2931,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, | |
} | ||
case ExpressionTypes::ExtNegatedMulAccReduction: { | ||
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); | ||
O << " + reduce." | ||
<< Instruction::getOpcodeName( | ||
O << " + " << (IsPartialReduction ? "partial." : "") << "reduce."; | ||
O << Instruction::getOpcodeName( | ||
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind())) | ||
<< " (sub (0, mul"; | ||
auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]); | ||
|
@@ -2934,9 +2956,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, | |
case ExpressionTypes::MulAccReduction: | ||
case ExpressionTypes::ExtMulAccReduction: { | ||
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); | ||
O << " + "; | ||
O << "reduce." | ||
<< Instruction::getOpcodeName( | ||
O << " + " << (IsPartialReduction ? "partial." : "") << "reduce."; | ||
O << Instruction::getOpcodeName( | ||
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind())) | ||
<< " ("; | ||
O << "mul"; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this was missed before and only now is tested?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that's right.