Skip to content

Commit 8dfd2b7

Browse files
Have certain mask simplification operations happen earlier than morph (#117101)
1 parent 0e04707 commit 8dfd2b7

File tree

5 files changed

+257
-335
lines changed

5 files changed

+257
-335
lines changed

src/coreclr/jit/compiler.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6696,15 +6696,6 @@ class Compiler
66966696
GenTree* fgMorphHWIntrinsicOptional(GenTreeHWIntrinsic* tree);
66976697
GenTree* fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node);
66986698
GenTree* fgOptimizeHWIntrinsicAssociative(GenTreeHWIntrinsic* node);
6699-
#if defined(FEATURE_MASKED_HW_INTRINSICS)
6700-
GenTreeHWIntrinsic* fgOptimizeForMaskedIntrinsic(GenTreeHWIntrinsic* node);
6701-
#endif // FEATURE_MASKED_HW_INTRINSICS
6702-
#ifdef TARGET_ARM64
6703-
bool canMorphVectorOperandToMask(GenTree* node);
6704-
bool canMorphAllVectorOperandsToMasks(GenTreeHWIntrinsic* node);
6705-
GenTree* doMorphVectorOperandToMask(GenTree* node, GenTreeHWIntrinsic* parent);
6706-
GenTreeHWIntrinsic* fgMorphTryUseAllMaskVariant(GenTreeHWIntrinsic* node);
6707-
#endif // TARGET_ARM64
67086699
#endif // FEATURE_HW_INTRINSICS
67096700
GenTree* fgOptimizeCommutativeArithmetic(GenTreeOp* tree);
67106701
GenTree* fgOptimizeRelationalComparisonWithCasts(GenTreeOp* cmp);

src/coreclr/jit/gentree.cpp

Lines changed: 208 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32112,10 +32112,8 @@ bool GenTree::CanDivOrModPossiblyOverflow(Compiler* comp) const
3211232112
#if defined(FEATURE_HW_INTRINSICS)
3211332113
GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3211432114
{
32115-
if (!opts.Tier0OptimizationEnabled())
32116-
{
32117-
return tree;
32118-
}
32115+
assert(!optValnumCSE_phase);
32116+
assert(opts.Tier0OptimizationEnabled());
3211932117

3212032118
NamedIntrinsic ni = tree->GetHWIntrinsicId();
3212132119
var_types retType = tree->TypeGet();
@@ -32254,6 +32252,126 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3225432252
// We shouldn't find AND_NOT nodes since it should only be produced in lowering
3225532253
assert(oper != GT_AND_NOT);
3225632254

32255+
#if defined(FEATURE_MASKED_HW_INTRINSICS) && defined(TARGET_XARCH)
32256+
if (GenTreeHWIntrinsic::OperIsBitwiseHWIntrinsic(oper))
32257+
{
32258+
// Comparisons that produce masks lead to more verbose trees than
32259+
// necessary in many scenarios due to requiring a CvtMaskToVector
32260+
// node to be inserted over them and this can block various opts
32261+
// that are dependent on tree height and similar. So we want to
32262+
// fold the unnecessary back and forth conversions away where possible.
32263+
32264+
genTreeOps effectiveOper = oper;
32265+
GenTree* actualOp2 = op2;
32266+
32267+
if (oper == GT_NOT)
32268+
{
32269+
assert(op2 == nullptr);
32270+
op2 = op1;
32271+
}
32272+
32273+
// We need both operands to be ConvertMaskToVector in
32274+
// order to optimize this to a direct mask operation
32275+
32276+
if (op1->OperIsConvertMaskToVector())
32277+
{
32278+
if (!op2->OperIsHWIntrinsic())
32279+
{
32280+
if ((oper == GT_XOR) && op2->IsVectorAllBitsSet())
32281+
{
32282+
// We want to explicitly recognize op1 ^ AllBitsSet as
32283+
// some platforms don't have direct support for ~op1
32284+
32285+
effectiveOper = GT_NOT;
32286+
op2 = op1;
32287+
}
32288+
}
32289+
32290+
if (op2->OperIsConvertMaskToVector())
32291+
{
32292+
GenTreeHWIntrinsic* cvtOp1 = op1->AsHWIntrinsic();
32293+
GenTreeHWIntrinsic* cvtOp2 = op2->AsHWIntrinsic();
32294+
32295+
unsigned simdBaseTypeSize = genTypeSize(simdBaseType);
32296+
32297+
if ((genTypeSize(cvtOp1->GetSimdBaseType()) == simdBaseTypeSize) &&
32298+
(genTypeSize(cvtOp2->GetSimdBaseType()) == simdBaseTypeSize))
32299+
{
32300+
// We need both operands to be the same kind of mask; otherwise
32301+
// the bitwise operation can differ in how it performs
32302+
32303+
NamedIntrinsic maskIntrinsicId = NI_Illegal;
32304+
32305+
switch (effectiveOper)
32306+
{
32307+
case GT_AND:
32308+
{
32309+
maskIntrinsicId = NI_AVX512_AndMask;
32310+
break;
32311+
}
32312+
32313+
case GT_NOT:
32314+
{
32315+
maskIntrinsicId = NI_AVX512_NotMask;
32316+
break;
32317+
}
32318+
32319+
case GT_OR:
32320+
{
32321+
maskIntrinsicId = NI_AVX512_OrMask;
32322+
break;
32323+
}
32324+
32325+
case GT_XOR:
32326+
{
32327+
maskIntrinsicId = NI_AVX512_XorMask;
32328+
break;
32329+
}
32330+
32331+
default:
32332+
{
32333+
unreached();
32334+
}
32335+
}
32336+
32337+
assert(maskIntrinsicId != NI_Illegal);
32338+
32339+
if (effectiveOper == oper)
32340+
{
32341+
tree->ChangeHWIntrinsicId(maskIntrinsicId);
32342+
tree->Op(1) = cvtOp1->Op(1);
32343+
}
32344+
else
32345+
{
32346+
assert(effectiveOper == GT_NOT);
32347+
tree->ResetHWIntrinsicId(maskIntrinsicId, this, cvtOp1->Op(1));
32348+
tree->gtFlags &= ~GTF_REVERSE_OPS;
32349+
}
32350+
32351+
tree->gtType = TYP_MASK;
32352+
DEBUG_DESTROY_NODE(op1);
32353+
32354+
if (effectiveOper != GT_NOT)
32355+
{
32356+
tree->Op(2) = cvtOp2->Op(1);
32357+
}
32358+
32359+
if (actualOp2 != nullptr)
32360+
{
32361+
DEBUG_DESTROY_NODE(actualOp2);
32362+
}
32363+
tree->SetMorphed(this);
32364+
32365+
tree = gtNewSimdCvtMaskToVectorNode(retType, tree, simdBaseJitType, simdSize)->AsHWIntrinsic();
32366+
tree->SetMorphed(this);
32367+
32368+
return tree;
32369+
}
32370+
}
32371+
}
32372+
}
32373+
#endif // FEATURE_MASKED_HW_INTRINSICS && TARGET_XARCH
32374+
3225732375
switch (ni)
3225832376
{
3225932377
// There's certain IR simplifications that are possible and which
@@ -32830,10 +32948,28 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3283032948
oper = GT_NONE;
3283132949
}
3283232950

32951+
// For mask nodes in particular, the foldings below are done under the presumption
32952+
// that we only produce something like `AddMask(op1, op2)` if op1 and op2 are compatible
32953+
// masks. On xarch, for example, this means that it'd be adding 8, 16, 32, or 64-bits
32954+
// together with the same size. We wouldn't ever encounter something like an 8 and 16 bit
32955+
// masks being added. This ensures that we don't end up with a case where folding would
32956+
// cause a different result to be produced, such as because the remaining upper bits are
32957+
// no longer zeroed.
32958+
3283332959
switch (oper)
3283432960
{
3283532961
case GT_ADD:
3283632962
{
32963+
if (varTypeIsMask(retType))
32964+
{
32965+
// Handle `x + 0 == x` and `0 + x == x`
32966+
if (cnsNode->IsMaskZero())
32967+
{
32968+
resultNode = otherNode;
32969+
}
32970+
break;
32971+
}
32972+
3283732973
if (varTypeIsFloating(simdBaseType))
3283832974
{
3283932975
// Handle `x + NaN == NaN` and `NaN + x == NaN`
@@ -32867,6 +33003,23 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3286733003

3286833004
case GT_AND:
3286933005
{
33006+
if (varTypeIsMask(retType))
33007+
{
33008+
// Handle `x & 0 == 0` and `0 & x == 0`
33009+
if (cnsNode->IsMaskZero())
33010+
{
33011+
resultNode = gtWrapWithSideEffects(cnsNode, otherNode, GTF_ALL_EFFECT);
33012+
break;
33013+
}
33014+
33015+
// Handle `x & AllBitsSet == x` and `AllBitsSet & x == x`
33016+
if (cnsNode->IsMaskAllBitsSet())
33017+
{
33018+
resultNode = otherNode;
33019+
}
33020+
break;
33021+
}
33022+
3287033023
// Handle `x & 0 == 0` and `0 & x == 0`
3287133024
if (cnsNode->IsVectorZero())
3287233025
{
@@ -33100,6 +33253,23 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3310033253

3310133254
case GT_OR:
3310233255
{
33256+
if (varTypeIsMask(retType))
33257+
{
33258+
// Handle `x | 0 == x` and `0 | x == x`
33259+
if (cnsNode->IsMaskZero())
33260+
{
33261+
resultNode = otherNode;
33262+
break;
33263+
}
33264+
33265+
// Handle `x | AllBitsSet == AllBitsSet` and `AllBitsSet | x == AllBitsSet`
33266+
if (cnsNode->IsMaskAllBitsSet())
33267+
{
33268+
resultNode = gtWrapWithSideEffects(cnsNode, otherNode, GTF_ALL_EFFECT);
33269+
}
33270+
break;
33271+
}
33272+
3310333273
// Handle `x | 0 == x` and `0 | x == x`
3310433274
if (cnsNode->IsVectorZero())
3310533275
{
@@ -33127,6 +33297,27 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3312733297
// Handle `x >> 0 == x` and `0 >> x == 0`
3312833298
// Handle `x >>> 0 == x` and `0 >>> x == 0`
3312933299

33300+
if (varTypeIsMask(retType))
33301+
{
33302+
if (cnsNode->IsMaskZero())
33303+
{
33304+
if (cnsNode == op2)
33305+
{
33306+
resultNode = otherNode;
33307+
}
33308+
else
33309+
{
33310+
resultNode = gtWrapWithSideEffects(cnsNode, otherNode, GTF_ALL_EFFECT);
33311+
}
33312+
}
33313+
else if (cnsNode->IsIntegralConst(0))
33314+
{
33315+
assert(cnsNode == op2);
33316+
resultNode = otherNode;
33317+
}
33318+
break;
33319+
}
33320+
3313033321
if (cnsNode->IsVectorZero())
3313133322
{
3313233323
if (cnsNode == op2)
@@ -33172,7 +33363,17 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3317233363

3317333364
case GT_XOR:
3317433365
{
33175-
// Handle `x | 0 == x` and `0 | x == x`
33366+
if (varTypeIsMask(retType))
33367+
{
33368+
// Handle `x ^ 0 == x` and `0 ^ x == x`
33369+
if (cnsNode->IsMaskZero())
33370+
{
33371+
resultNode = otherNode;
33372+
}
33373+
break;
33374+
}
33375+
33376+
// Handle `x ^ 0 == x` and `0 ^ x == x`
3317633377
if (cnsNode->IsVectorZero())
3317733378
{
3317833379
resultNode = otherNode;
@@ -33341,7 +33542,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3334133542
}
3334233543
else
3334333544
{
33344-
assert(!op1->IsTrueMask(simdBaseType) && !op1->IsFalseMask());
33545+
assert(!op1->IsTrueMask(simdBaseType) && !op1->IsMaskZero());
3334533546
}
3334633547
#endif
3334733548

@@ -33359,7 +33560,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3335933560
return op2;
3336033561
}
3336133562

33362-
if (op1->IsVectorZero() || op1->IsFalseMask())
33563+
if (op1->IsVectorZero() || op1->IsMaskZero())
3336333564
{
3336433565
return gtWrapWithSideEffects(op3, op2, GTF_ALL_EFFECT);
3336533566
}

src/coreclr/jit/gentree.h

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1803,8 +1803,9 @@ struct GenTree
18031803
inline bool IsVectorCreate() const;
18041804
inline bool IsVectorAllBitsSet() const;
18051805
inline bool IsVectorBroadcast(var_types simdBaseType) const;
1806+
inline bool IsMaskZero() const;
1807+
inline bool IsMaskAllBitsSet() const;
18061808
inline bool IsTrueMask(var_types simdBaseType) const;
1807-
inline bool IsFalseMask() const;
18081809

18091810
inline uint64_t GetIntegralVectorConstElement(size_t index, var_types simdBaseType);
18101811

@@ -9629,6 +9630,42 @@ inline bool GenTree::IsVectorBroadcast(var_types simdBaseType) const
96299630
return false;
96309631
}
96319632

9633+
//-------------------------------------------------------------------
9634+
// IsMaskZero: returns true if this node is a mask constant with all bits zero.
9635+
//
9636+
// Returns:
9637+
// True if this node is a mask constant with all bits zero
9638+
//
9639+
inline bool GenTree::IsMaskZero() const
9640+
{
9641+
#if defined(FEATURE_MASKED_HW_INTRINSICS)
9642+
if (IsCnsMsk())
9643+
{
9644+
return AsMskCon()->IsZero();
9645+
}
9646+
#endif // FEATURE_MASKED_HW_INTRINSICS
9647+
9648+
return false;
9649+
}
9650+
9651+
//-------------------------------------------------------------------
9652+
// IsMaskAllBitsSet: returns true if this node is a mask constant with all bits set.
9653+
//
9654+
// Returns:
9655+
// True if this node is a mask constant with all bits set
9656+
//
9657+
inline bool GenTree::IsMaskAllBitsSet() const
9658+
{
9659+
#if defined(FEATURE_MASKED_HW_INTRINSICS)
9660+
if (IsCnsMsk())
9661+
{
9662+
return AsMskCon()->IsAllBitsSet();
9663+
}
9664+
#endif // FEATURE_MASKED_HW_INTRINSICS
9665+
9666+
return false;
9667+
}
9668+
96329669
//------------------------------------------------------------------------
96339670
// IsTrueMask: Is the given node a true mask
96349671
//
@@ -9655,23 +9692,6 @@ inline bool GenTree::IsTrueMask(var_types simdBaseType) const
96559692
return false;
96569693
}
96579694

9658-
//------------------------------------------------------------------------
9659-
// IsFalseMask: Is the given node a false mask
9660-
//
9661-
// Returns true if the node is a false mask, ie all zeros
9662-
//
9663-
inline bool GenTree::IsFalseMask() const
9664-
{
9665-
#ifdef TARGET_ARM64
9666-
if (IsCnsMsk())
9667-
{
9668-
return AsMskCon()->IsZero();
9669-
}
9670-
#endif
9671-
9672-
return false;
9673-
}
9674-
96759695
//-------------------------------------------------------------------
96769696
// GetIntegralVectorConstElement: Gets the value of a given element in an integral vector constant
96779697
//

src/coreclr/jit/lowerarmarch.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3942,7 +3942,7 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
39423942
GenTree* op3 = intrin.op3;
39433943

39443944
// Handle op1
3945-
if (op1->IsFalseMask())
3945+
if (op1->IsMaskZero())
39463946
{
39473947
// When we are merging with zero, we can specialize
39483948
// and avoid instantiating the vector constant.

0 commit comments

Comments
 (0)