Skip to content

Commit f8d8a16

Browse files
authored
Brgemm register tiling support for bf16 type (#1005)
This PR extends the `brgemm register tiling` pass to support `bf16` type. The changes: 1) Template the existing pass to execute on `linalg.batch_reduce_matmul` for `fp32` and `linal.generic` for `vnni` opt bf16, 2) Test-cases for `bf16` type.
1 parent cb1e22f commit f8d8a16

File tree

7 files changed

+351
-333
lines changed

7 files changed

+351
-333
lines changed

benchmarks/config/base/base.json

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,21 @@
4040
"type": "IR-GEN",
4141
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
4242
"environment": {},
43-
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32 '" ],
43+
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32,1 '" ],
4444
"extensions": ["avx512.*"]
4545
},
4646
"gemm_fp32_mlir_vector_avx2": {
4747
"type": "IR-GEN",
4848
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
4949
"environment": {},
50-
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=4,16 '" ],
50+
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=4,16,1 '" ],
5151
"extensions": ["avx2"]
5252
},
5353
"gemm_fp32_mlir_vector_sve": {
5454
"type": "IR-GEN",
5555
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
5656
"environment": {},
57-
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=4,32 '" ],
57+
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=4,32,1 '" ],
5858
"extensions": ["asimd"]
5959
},
6060
"gemm_bf16_dp2_mlir": {
@@ -82,21 +82,21 @@
8282
"type": "IR-GEN",
8383
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
8484
"environment": {},
85-
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=8,32 '" ],
85+
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=8,32,1 '" ],
8686
"extensions": ["avx512.*"]
8787
},
8888
"mlp_fp32_mlir_vector_avx2": {
8989
"type": "IR-GEN",
9090
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
9191
"environment": {},
92-
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=4,16 '" ],
92+
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=4,16,1 '" ],
9393
"extensions": ["avx2" ]
9494
},
9595
"mlp_fp32_mlir_vector_sve": {
9696
"type": "IR-GEN",
9797
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024 --tiles=32,32,32" ],
9898
"environment": {},
99-
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=4,32 '" ],
99+
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=4,32,1 '" ],
100100
"extensions": ["asimd"]
101101
},
102102
"mlp_bf16_dp2_mlir": {
@@ -127,7 +127,7 @@
127127
"type": "IR-GEN",
128128
"benchmark": [ "mlir-gen", "--kernel=const --float-type=f32 --batch=256 --layers=1024,1024,1024,1024" ],
129129
"environment": {},
130-
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32 '" ],
130+
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32,1 '" ],
131131
"extensions": [ "avx512.*" ]
132132
},
133133
"fp32_3x1024_args_mlir": {
@@ -141,7 +141,7 @@
141141
"type": "IR-GEN",
142142
"benchmark": [ "mlir-gen", "--kernel=args --float-type=f32 --batch=256 --layers=1024,1024,1024,1024" ],
143143
"environment": {},
144-
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32 '" ],
144+
"flags": [ "-n", "100", "-run-args='--vector-to-kernels --registerBlocking=8,32,1 '" ],
145145
"extensions": [ "avx512.*" ]
146146
},
147147
"bf16_3x1024_const_mlir": {
@@ -172,7 +172,7 @@
172172
"type": "IR-GEN",
173173
"benchmark": [ "mlir-gen", "--kernel=const --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024" ],
174174
"environment": {},
175-
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=8,32 '" ],
175+
"flags": [ "-n", "100", "-run-args='--def-parallel --vector-to-kernels --registerBlocking=8,32,1 '" ],
176176
"extensions": [ "avx512.*" ]
177177
},
178178
"fp32_3x1024_args_mlir": {
@@ -186,7 +186,7 @@
186186
"type": "IR-GEN",
187187
"benchmark": [ "mlir-gen", "--kernel=args --bias --relu --float-type=f32 --batch=256 --layers=1024,1024,1024,1024" ],
188188
"environment": {},
189-
"flags": [ "-n", "100", "-run-args=' --def-parallel --vector-to-kernels --registerBlocking=8,32 '" ],
189+
"flags": [ "-n", "100", "-run-args=' --def-parallel --vector-to-kernels --registerBlocking=8,32,1 '" ],
190190
"extensions": [ "avx512.*" ]
191191
},
192192
"bf16_3x1024_const_mlir": {

0 commit comments

Comments
 (0)