Skip to content

Commit c8dc585

Browse files
authored
[μKernels]: Fix the broken ARM FP32 micro-kernel lowering (#1087)
This patch fixes issue #1086: the broken ARM FP32 micro-kernel lowering.
1 parent c630b35 commit c8dc585

File tree

5 files changed

+51
-23
lines changed

5 files changed

+51
-23
lines changed

benchmarks/config/base/base.json

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@
5252
},
5353
"gemm_fp32_mlir_vector_sve": {
5454
"type": "IR-GEN",
55-
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
55+
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=512 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
5656
"environment": {},
57-
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=4,32,1'" ],
57+
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ],
5858
"extensions": ["asimd"]
5959
},
6060
"gemm_bf16_vnni_dp2_mlir": {
@@ -129,9 +129,9 @@
129129
},
130130
"mlp_fp32_mlir_vector_sve": {
131131
"type": "IR-GEN",
132-
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
132+
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=512 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
133133
"environment": {},
134-
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=4,32,1'" ],
134+
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ],
135135
"extensions": ["asimd"]
136136
},
137137
"mlp_bf16_vnni_dp2_mlir": {

benchmarks/config/omp/mlir-fp32-vector-to-kernel.json

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,28 +189,28 @@
189189
"gemm_fp32_mlir_vector_kernel_32_sve": {
190190
"fp32_3x1024_omp_2_mlir": {
191191
"type": "IR-GEN",
192-
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
192+
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=512 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
193193
"environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
194194
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ],
195195
"extensions": [ "asimd" ]
196196
},
197197
"fp32_3x1024_omp_4_mlir": {
198198
"type": "IR-GEN",
199-
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
199+
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=512 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
200200
"environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
201201
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ],
202202
"extensions": [ "asimd" ]
203203
},
204204
"fp32_3x1024_omp_8_mlir": {
205205
"type": "IR-GEN",
206-
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
206+
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=512 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
207207
"environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
208208
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ],
209209
"extensions": [ "asimd" ]
210210
},
211211
"fp32_3x1024_omp_16_mlir": {
212212
"type": "IR-GEN",
213-
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
213+
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=512 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
214214
"environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
215215
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ],
216216
"extensions": [ "asimd" ]
@@ -220,28 +220,28 @@
220220
"mlp_fp32_mlir_vector_kernel_32_sve": {
221221
"fp32_3x1024_omp_2_mlir": {
222222
"type": "IR-GEN",
223-
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
223+
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=512 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
224224
"environment": { "OMP_NUM_THREADS": "2", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
225225
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,16 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ],
226226
"extensions": [ "asimd" ]
227227
},
228228
"fp32_3x1024_omp_4_mlir": {
229229
"type": "IR-GEN",
230-
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
230+
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=512 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
231231
"environment": { "OMP_NUM_THREADS": "4", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
232232
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=8,8 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ],
233233
"extensions": [ "asimd" ]
234234
},
235235
"fp32_3x1024_omp_8_mlir": {
236236
"type": "IR-GEN",
237-
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
237+
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=512 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
238238
"environment": { "OMP_NUM_THREADS": "8", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
239239
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=4,8 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ],
240240
"extensions": [ "asimd" ]
241241
},
242242
"fp32_3x1024_omp_16_mlir": {
243243
"type": "IR-GEN",
244-
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
244+
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=512 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
245245
"environment": { "OMP_NUM_THREADS": "16", "KMP_AFFINITY": "granularity=fine,verbose,compact,1,0" },
246246
"flags": [ "-n", "100", "-run-args='--def-parallel --parallel-task-grid=2,8 --vector-to-kernels --registerBlocking=4,32,1 -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256'" ],
247247
"extensions": [ "asimd" ]

include/TPP/Transforms/Utils/VNNIUtils.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,22 @@ enum class VnniOperandRank {
3636
BRGEMM_INS = 4,
3737
BRGEMM_OUTS = 3
3838
};
39-
// Returns True if the current architecture supports AVX2 instructions.
39+
40+
// Returns true if the current architecture supports AVX2 instructions.
4041
bool hasAVX2();
4142

42-
// Returns True if the current architecture supports AVX512 instructions.
43+
// Returns true if the current architecture supports AVX512 instructions.
4344
bool hasAVX512();
4445

45-
// Returns True if the current architecture supports AMX instructions.
46+
// Returns true if the current architecture supports AMX instructions.
4647
bool hasAMX();
4748

49+
// Returns true if the current architecture supports SVE-256 instructions.
50+
bool hasSVE256();
51+
52+
// Returns true if the current architecture supports SVE-512 instructions.
53+
bool hasSVE512();
54+
4855
// Returns the current target architecture name
4956
std::string getTargetArchName();
5057

lib/TPP/Transforms/Utils/VNNIUtils.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,36 @@ namespace mlir {
2323
namespace vnni {
2424
namespace utils {
2525

26-
// Returns True if the current architecture supports AMX instructions.
26+
// Returns true if the current architecture supports AMX instructions.
2727
bool hasAMX() {
2828
return (libxsmm_get_target_archid() >= LIBXSMM_X86_AVX512_SPR) &&
2929
(libxsmm_get_target_archid() < LIBXSMM_X86_ALLFEAT);
3030
}
3131

32-
// Returns True if the current architecture supports AMX instructions.
32+
// Returns true if the current architecture supports AVX2 instructions.
3333
bool hasAVX2() {
3434
return (libxsmm_get_target_archid() >= LIBXSMM_X86_AVX2) &&
3535
(libxsmm_get_target_archid() < LIBXSMM_X86_ALLFEAT);
3636
}
3737

38-
// Returns True if the current architecture supports AMX instructions.
38+
// Returns true if the current architecture supports AVX512 instructions.
3939
bool hasAVX512() {
4040
return (libxsmm_get_target_archid() >= LIBXSMM_X86_AVX512_SKX) &&
4141
(libxsmm_get_target_archid() < LIBXSMM_X86_ALLFEAT);
4242
}
4343

44+
// Returns true if the current architecture supports SVE-256 instructions.
45+
bool hasSVE256() {
46+
return (libxsmm_get_target_archid() >= LIBXSMM_AARCH64_NEOV2) &&
47+
(libxsmm_get_target_archid() <= LIBXSMM_AARCH64_NEOV1);
48+
}
49+
50+
// Returns true if the current architecture supports SVE-512 instructions.
51+
bool hasSVE512() {
52+
return (libxsmm_get_target_archid() >= LIBXSMM_AARCH64_SVE512) &&
53+
(libxsmm_get_target_archid() <= LIBXSMM_AARCH64_A64FX);
54+
}
55+
4456
// Returns the current target architecture name
4557
std::string getTargetArchName() {
4658
if (libxsmm_get_target_archid() == LIBXSMM_X86_AVX2_SRF)

lib/TPP/Transforms/VectorContractToMicroKernels.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -347,23 +347,32 @@ struct MicroKernelsOp : OpRewritePattern<vector::ContractionOp> {
347347
// We get target architecture and decide on uKernel lowering using flags
348348
bool avx512 = vnni::utils::hasAVX512();
349349
bool avx2 = vnni::utils::hasAVX2();
350+
bool sve256 = vnni::utils::hasSVE256();
351+
bool sve512 = vnni::utils::hasSVE512();
350352

351353
// disable avx512, if target feature is avx2
352354
if (options.targetFeature == "avx2")
353355
avx512 = false;
354356

355-
int64_t sizeFactor = avx512 ? 16 : avx2 ? 8 : 0;
356-
357-
if (sizeFactor == 0)
358-
return rewriter.notifyMatchFailure(
359-
contractOp, "AVX512 or AVX2 required for this pass");
357+
int64_t sizeFactor = (avx512 || sve512) ? 16 : (avx2 || sve256) ? 8 : 0;
360358

361359
bool isF32 = elementType.isF32();
362360
bool isF16 = elementType.isF16();
363361
bool isBF16 = elementType.isBF16();
364362
bool isI8 = elementType.isSignlessInteger(8);
365363

366364
bool isPackedType = isF16 || isBF16 || isI8;
365+
366+
if (sizeFactor == 0)
367+
return rewriter.notifyMatchFailure(
368+
contractOp, "AVX512 or AVX2 or SVE512/256 instruction set is not available or "
369+
"lowering is not available for this target machine.");
370+
371+
if ((sve256 || sve512) && isPackedType)
372+
return rewriter.notifyMatchFailure(
373+
contractOp,
374+
"only FP32 type lowering is supported for AARCH64(ARM) machines.");
375+
367376
int64_t vnniFactor = (isBF16 || isF16) ? 2 : isI8 ? 4 : 1;
368377
bool isSplat = false;
369378

0 commit comments

Comments
 (0)