Commit 09cbb85

Arm64: re-enable use of predicate variants

Fixes dotnet#101970. Predicate variants were implemented in dotnet#114438 and then turned off in dotnet#115566. The code was then removed in dotnet#117101, when the AMD64 version was moved from morph to folding. This is a simple rework of that code. Replaces dotnet#116854.

Parent: 4cc3020
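For context, here is a minimal sketch (not part of the commit) of the kind of pattern this fold targets, modelled on the updated test file below. Both operands of Sve.ZipLow are predicates produced by SVE compares, so the fold lets the JIT select the predicate (p-register) form of zip1 instead of converting each mask to a vector first. The names PredicateFoldSketch and ZipLowOfMasks are invented for illustration.

// Illustrative sketch only, based on ZipLowMask in PredicateInstructions.cs below.
// Both ZipLow operands are masks (compare results), so the fold rewrites the node
// to its mask variant and the Arm64 backend can emit the predicate form, e.g.
//     zip1 p0.h, p0.h, p1.h
// rather than materialising each mask as a vector and using the z-register form.
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics.Arm;

static class PredicateFoldSketch
{
    [MethodImpl(MethodImplOptions.NoInlining)]
    public static Vector<short> ZipLowOfMasks(Vector<short> a, Vector<short> b)
    {
        // Each compare produces a mask; ZipLow of two masks maps to the p-register zip1.
        return Sve.ZipLow(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b));
    }
}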

2 files changed (+172, -58 lines)

src/coreclr/jit/gentree.cpp

Lines changed: 63 additions & 3 deletions
@@ -32252,7 +32252,8 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
     // We shouldn't find AND_NOT nodes since it should only be produced in lowering
     assert(oper != GT_AND_NOT);
 
-#if defined(FEATURE_MASKED_HW_INTRINSICS) && defined(TARGET_XARCH)
+#if defined(FEATURE_MASKED_HW_INTRINSICS)
+#if defined(TARGET_XARCH)
     if (GenTreeHWIntrinsic::OperIsBitwiseHWIntrinsic(oper))
     {
         // Comparisons that produce masks lead to more verbose trees than
@@ -32370,7 +32371,66 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
             }
         }
     }
-#endif // FEATURE_MASKED_HW_INTRINSICS && TARGET_XARCH
+#elif defined(TARGET_ARM64)
+    // Check if the tree can be folded into a mask variant
+    if (HWIntrinsicInfo::HasAllMaskVariant(tree->GetHWIntrinsicId()))
+    {
+        NamedIntrinsic maskVariant = HWIntrinsicInfo::GetMaskVariant(tree->GetHWIntrinsicId());
+
+        assert(opCount == (size_t)HWIntrinsicInfo::lookupNumArgs(maskVariant));
+
+        // Check all operands are valid
+        bool canFold = true;
+        for (size_t i = 1; i <= opCount && canFold; i++)
+        {
+            canFold &=
+                (varTypeIsMask(tree->Op(i)) || tree->Op(i)->OperIsConvertMaskToVector() || tree->Op(i)->IsVectorZero());
+        }
+
+        if (canFold)
+        {
+            // Convert all the operands to masks
+            for (size_t i = 1; i <= opCount; i++)
+            {
+                if (tree->Op(i)->OperIsConvertMaskToVector())
+                {
+                    // Replace with op1.
+                    tree->Op(i) = tree->Op(i)->AsHWIntrinsic()->Op(1);
+                }
+                else if (tree->Op(i)->IsVectorZero())
+                {
+                    // Replace the vector of zeroes with a mask of zeroes.
+                    tree->Op(i) = gtNewSimdFalseMaskByteNode();
+                    tree->Op(i)->SetMorphed(this);
+                }
+                assert(varTypeIsMask(tree->Op(i)));
+            }
+
+            // Switch to the mask variant
+            switch (opCount)
+            {
+                case 1:
+                    tree->ResetHWIntrinsicId(maskVariant, tree->Op(1));
+                    break;
+                case 2:
+                    tree->ResetHWIntrinsicId(maskVariant, tree->Op(1), tree->Op(2));
+                    break;
+                case 3:
+                    tree->ResetHWIntrinsicId(maskVariant, this, tree->Op(1), tree->Op(2), tree->Op(3));
+                    break;
+                default:
+                    unreached();
+            }
+
+            tree->gtType = TYP_MASK;
+            tree->SetMorphed(this);
+            tree = gtNewSimdCvtMaskToVectorNode(retType, tree, simdBaseJitType, simdSize)->AsHWIntrinsic();
+            tree->SetMorphed(this);
+            return tree;
+        }
+    }
+#endif // TARGET_XARCH
+#endif // FEATURE_MASKED_HW_INTRINSICS
 
     switch (ni)
     {
@@ -33605,7 +33665,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
             assert(op2->gtType == TYP_SIMD16);
            assert(op3->gtType == TYP_SIMD16);
 
-            simd16_t op1SimdVal;
+            simd16_t op1SimdVal = {};
             EvaluateSimdCvtMaskToVector<simd16_t>(simdBaseType, &op1SimdVal, op1->AsMskCon()->gtSimdMaskVal);
 
             // op2 = op2 & op1
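The fold above accepts an operand if it is already a mask, a ConvertMaskToVector node (which is unwrapped back to its mask), or a vector of zeroes (which gtNewSimdFalseMaskByteNode replaces with an all-false mask). A hedged C# sketch of that last case, mirroring the AndMask test below: the zero fallback of the ConditionalSelect becomes a false predicate, so the whole select can collapse to the zeroing-predicated form that the test checks for. The method name AndWithZeroFallback is invented for illustration.

// Illustrative sketch only, mirroring AndMask in PredicateInstructions.cs below.
// All three ConditionalSelect operands are mask-convertible:
//   - the governing operand is a true mask,
//   - Sve.And combines two compare results (both masks),
//   - Vector<short>.Zero is rewritten by the fold to an all-false mask,
// so the tree folds to the predicate variants and the expected output is the
// zeroing form, e.g. and p0.b, p1/z, p2.b, p3.b
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<short> AndWithZeroFallback(Vector<short> a, Vector<short> b)
{
    return Sve.ConditionalSelect(
        Sve.CreateTrueMaskInt16(),
        Sve.And(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
        Vector<short>.Zero);
}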

src/tests/JIT/opt/SVE/PredicateInstructions.cs

Lines changed: 109 additions & 55 deletions
@@ -17,110 +17,164 @@ public static void TestPredicateInstructions()
     {
         if (Sve.IsSupported)
         {
-            ZipLow();
-            ZipHigh();
-            UnzipOdd();
-            UnzipEven();
-            TransposeOdd();
-            TransposeEven();
-            ReverseElement();
-            And();
-            BitwiseClear();
-            Xor();
-            Or();
-            ConditionalSelect();
+            Vector<sbyte> vecsb = Vector.Create<sbyte>(2);
+            Vector<short> vecs = Vector.Create<short>(2);
+            Vector<ushort> vecus = Vector.Create<ushort>(2);
+            Vector<int> veci = Vector.Create<int>(3);
+            Vector<uint> vecui = Vector.Create<uint>(5);
+            Vector<long> vecl = Vector.Create<long>(7);
+
+            ZipLowMask(vecs, vecs);
+            ZipHighMask(vecui, vecui);
+            UnzipOddMask(vecs, vecs);
+            UnzipEvenMask(vecsb, vecsb);
+            TransposeEvenMask(vecl, vecl);
+            TransposeOddMask(vecs, vecs);
+            ReverseElementMask(vecs, vecs);
+            AndMask(vecs, vecs);
+            BitwiseClearMask(vecs, vecs);
+            XorMask(veci, veci);
+            OrMask(vecs, vecs);
+            ConditionalSelectMask(veci, veci, veci);
+
+            UnzipEvenZipLowMask(vecs, vecs);
+            TransposeEvenAndMask(vecs, vecs, vecs);
         }
     }
 
+    // These should use the predicate variants.
+    // Sve intrinsics that return masks (Compare) or use mask arguments (CreateBreakAfterMask) are used
+    // to ensure masks are used.
+
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<short> ZipLow()
+    static Vector<short> ZipLowMask(Vector<short> a, Vector<short> b)
     {
-        return Sve.ZipLow(Vector<short>.Zero, Sve.CreateTrueMaskInt16());
+        //ARM64-FULL-LINE: zip1 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
+        return Sve.ZipLow(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b));
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<uint> ZipHigh()
+    static Vector<uint> ZipHighMask(Vector<uint> a, Vector<uint> b)
     {
-        return Sve.ZipHigh(Sve.CreateTrueMaskUInt32(), Sve.CreateTrueMaskUInt32());
+        //ARM64-FULL-LINE: zip2 {{p[0-9]+}}.s, {{p[0-9]+}}.s, {{p[0-9]+}}.s
+        return Sve.CreateBreakAfterMask(Sve.ZipHigh(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)), Sve.CreateTrueMaskUInt32());
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<sbyte> UnzipEven()
+    static Vector<sbyte> UnzipEvenMask(Vector<sbyte> a, Vector<sbyte> b)
     {
-        return Sve.UnzipEven(Sve.CreateTrueMaskSByte(), Vector<sbyte>.Zero);
+        //ARM64-FULL-LINE: uzp1 {{p[0-9]+}}.b, {{p[0-9]+}}.b, {{p[0-9]+}}.b
+        return Sve.UnzipEven(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b));
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<short> UnzipOdd()
+    static Vector<short> UnzipOddMask(Vector<short> a, Vector<short> b)
    {
-        return Sve.UnzipOdd(Sve.CreateTrueMaskInt16(), Sve.CreateFalseMaskInt16());
+        //ARM64-FULL-LINE: uzp2 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
+        return Sve.CreateBreakAfterMask(Sve.UnzipOdd(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)), Sve.CreateTrueMaskInt16());
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<long> TransposeEven()
+    static Vector<long> TransposeEvenMask(Vector<long> a, Vector<long> b)
     {
-        return Sve.TransposeEven(Sve.CreateFalseMaskInt64(), Sve.CreateTrueMaskInt64());
+        //ARM64-FULL-LINE: trn1 {{p[0-9]+}}.d, {{p[0-9]+}}.d, {{p[0-9]+}}.d
+        return Sve.CreateBreakAfterMask(Sve.TransposeEven(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)), Sve.CreateFalseMaskInt64());
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<short> TransposeOdd()
+    static Vector<short> TransposeOddMask(Vector<short> a, Vector<short> b)
     {
-        return Sve.TransposeOdd(Vector<short>.Zero, Sve.CreateTrueMaskInt16());
+        //ARM64-FULL-LINE: trn2 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
+        return Sve.TransposeOdd(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b));
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<short> ReverseElement()
+    static Vector<short> ReverseElementMask(Vector<short> a, Vector<short> b)
     {
-        return Sve.ReverseElement(Sve.CreateTrueMaskInt16());
+        //ARM64-FULL-LINE: rev {{p[0-9]+}}.h, {{p[0-9]+}}.h
+        return Sve.CreateBreakAfterMask(Sve.ReverseElement(Sve.CompareGreaterThan(a, b)), Sve.CreateFalseMaskInt16());
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<short> And()
+    static Vector<short> AndMask(Vector<short> a, Vector<short> b)
     {
-        return Sve.ConditionalSelect(
-            Sve.CreateTrueMaskInt16(),
-            Sve.And(Sve.CreateTrueMaskInt16(), Sve.CreateTrueMaskInt16()),
-            Vector<short>.Zero
-        );
+        //ARM64-FULL-LINE: and {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
+        return Sve.CreateBreakAfterMask(
+            Sve.ConditionalSelect(
+                Sve.CreateTrueMaskInt16(),
+                Sve.And(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
+                Vector<short>.Zero),
+            Sve.CreateFalseMaskInt16());
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<short> BitwiseClear()
+    static Vector<short> BitwiseClearMask(Vector<short> a, Vector<short> b)
     {
+        //ARM64-FULL-LINE: bic {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
         return Sve.ConditionalSelect(
-            Sve.CreateFalseMaskInt16(),
-            Sve.BitwiseClear(Sve.CreateTrueMaskInt16(), Sve.CreateTrueMaskInt16()),
-            Vector<short>.Zero
-        );
+            Sve.CreateTrueMaskInt16(),
+            Sve.BitwiseClear(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
+            Vector<short>.Zero);
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<int> Xor()
+    static Vector<int> XorMask(Vector<int> a, Vector<int> b)
     {
-        return Sve.ConditionalSelect(
-            Sve.CreateTrueMaskInt32(),
-            Sve.Xor(Sve.CreateTrueMaskInt32(), Sve.CreateTrueMaskInt32()),
-            Vector<int>.Zero
-        );
+        //ARM64-FULL-LINE: eor {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
+        return Sve.CreateBreakAfterMask(
+            Sve.ConditionalSelect(
+                Sve.CreateTrueMaskInt32(),
+                Sve.Xor(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
+                Vector<int>.Zero),
+            Sve.CreateFalseMaskInt32());
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<short> Or()
+    static Vector<short> OrMask(Vector<short> a, Vector<short> b)
     {
+        //ARM64-FULL-LINE: orr {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
         return Sve.ConditionalSelect(
-            Sve.CreateTrueMaskInt16(),
-            Sve.Or(Sve.CreateTrueMaskInt16(), Sve.CreateTrueMaskInt16()),
-            Vector<short>.Zero
-        );
+            Sve.CreateTrueMaskInt16(),
+            Sve.Or(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
+            Vector<short>.Zero);
     }
 
     [MethodImpl(MethodImplOptions.NoInlining)]
-    static Vector<int> ConditionalSelect()
+    static Vector<int> ConditionalSelectMask(Vector<int> v, Vector<int> a, Vector<int> b)
     {
-        return Sve.ConditionalSelect(
-            Vector<int>.Zero,
-            Sve.CreateFalseMaskInt32(),
-            Sve.CreateTrueMaskInt32()
-        );
+        // Use a passed in vector for the mask to prevent optimising away the select
+        //ARM64-FULL-LINE: sel {{p[0-9]+}}.b, {{p[0-9]+}}, {{p[0-9]+}}.b, {{p[0-9]+}}.b
+        return Sve.CreateBreakAfterMask(
+            Sve.ConditionalSelect(v, Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
+            Sve.CreateFalseMaskInt32());
+    }
+
+    // These have multiple uses of the predicate variants
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static Vector<short> UnzipEvenZipLowMask(Vector<short> a, Vector<short> b)
+    {
+        //ARM64-FULL-LINE: zip1 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
+        //ARM64-FULL-LINE: uzp1 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
+        return Sve.CreateBreakAfterMask(
+            Sve.UnzipEven(
+                Sve.ZipLow(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
+                Sve.CompareLessThan(a, b)),
+            Sve.CreateTrueMaskInt16());
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static Vector<short> TransposeEvenAndMask(Vector<short> v, Vector<short> a, Vector<short> b)
+    {
+        //ARM64-FULL-LINE: and {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
+        //ARM64-FULL-LINE: trn1 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
+        return Sve.TransposeEven(
+            Sve.CompareGreaterThan(a, b),
+            Sve.ConditionalSelect(
+                Sve.CreateTrueMaskInt16(),
+                Sve.And(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
+                Sve.CompareLessThan(a, b)));
     }
-}
+}
