@@ -35,6 +35,34 @@ module {
35
35
36
36
// -----
37
37
38
+ module {
39
+ func.func @brgemm_tensor_type_tiling (%arg0: tensor <128 x256 x512 xf32 >, %arg1: tensor <128 x512 x256 xf32 >, %arg2: tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 > {
40
+ %0 = linalg.batch_reduce_matmul ins (%arg0 , %arg1 : tensor <128 x256 x512 xf32 >, tensor <128 x512 x256 xf32 >) outs (%arg2 : tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 >
41
+ return %0 : tensor <256 x256 xf32 >
42
+ }
43
+ }
44
+
45
+
46
+ // CONF1-LABEL: func.func @brgemm_tensor_type_tiling
47
+ // CONF1-DAG: %[[C0:.+]] = arith.constant 0 : index
48
+ // CONF1-DAG: %[[C256:.+]] = arith.constant 256 : index
49
+ // CONF1-DAG: %[[C8:.+]] = arith.constant 8 : index
50
+ // CONF1-DAG: %[[C32:.+]] = arith.constant 32 : index
51
+ // CONF1-DAG: %[[C128:.+]] = arith.constant 128 : index
52
+ // CONF1-DAG: %[[C1:.+]] = arith.constant 1 : index
53
+ // CONF1-DAG: %[[C512:.+]] = arith.constant 512 : index
54
+ // CONF1: %0 = scf.for %[[I:.+]] = %[[C0]] to %[[C256]] step %[[C8]] iter_args(%arg4 = %arg2) -> (tensor<256x256xf32>) {
55
+ // CONF1-NEXT: %1 = scf.for %[[J:.+]] = %[[C0]] to %[[C256]] step %[[C32]] iter_args(%arg6 = %arg4) -> (tensor<256x256xf32>) {
56
+ // CONF1-NEXT: %2 = scf.for %[[K:.+]] = %[[C0]] to %[[C128]] step %[[C1]] iter_args(%arg8 = %arg6) -> (tensor<256x256xf32>) {
57
+ // CONF1-NEXT: %3 = scf.for %[[L:.+]] = %[[C0]] to %[[C512]] step %[[C1]] iter_args(%arg10 = %arg8) -> (tensor<256x256xf32>) {
58
+ // CONF1-NEXT: %extracted_slice = tensor.extract_slice %arg0[%[[K]], %[[I]], %[[L]]] [1, 8, 1] [1, 1, 1] : tensor<128x256x512xf32> to tensor<1x8x1xf32>
59
+ // CONF1-NEXT: %extracted_slice_0 = tensor.extract_slice %arg1[%[[K]], %[[L]], %[[J]]] [1, 1, 32] [1, 1, 1] : tensor<128x512x256xf32> to tensor<1x1x32xf32>
60
+ // CONF1-NEXT: %extracted_slice_1 = tensor.extract_slice %arg10[%[[I]], %[[J]]] [8, 32] [1, 1] : tensor<256x256xf32> to tensor<8x32xf32>
61
+ // CONF1-NEXT: %4 = linalg.batch_reduce_matmul ins(%extracted_slice, %extracted_slice_0 : tensor<1x8x1xf32>, tensor<1x1x32xf32>) outs(%extracted_slice_1 : tensor<8x32xf32>) -> tensor<8x32xf32>
62
+ // CONF1-NEXT: %inserted_slice = tensor.insert_slice %4 into %arg10[%[[I]], %[[J]]] [8, 32] [1, 1] : tensor<8x32xf32> into tensor<256x256xf32>
63
+
64
+ // -----
65
+
38
66
#map = affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d2 , d4 , d1 )>
39
67
#map1 = affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d4 , d3 , d1 )>
40
68
#map2 = affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d2 , d3 )>
@@ -124,17 +152,48 @@ module {
124
152
125
153
// -----
126
154
155
+
156
+ #map = affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d2 , d4 , d1 )>
157
+ #map1 = affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d4 , d3 , d1 )>
158
+ #map2 = affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d2 , d3 )>
127
159
module {
128
- func.func @brgemm_tensor_type_no_tiling (%arg0: tensor <128 x256 x512 xf32 >, %arg1: tensor <128 x512 x256 xf32 >, %arg2: tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 > {
129
- %0 = linalg.batch_reduce_matmul ins (%arg0 , %arg1 : tensor <128 x256 x512 xf32 >, tensor <128 x512 x256 xf32 >) outs (%arg2 : tensor <256 x256 xf32 >) -> tensor <256 x256 xf32 >
130
- return %0 : tensor <256 x256 xf32 >
160
+ func.func @gemm_64tiles_do_tiling_bf16_tensor (%arg0: tensor <4 x16 x64 x64 xbf16 >) -> tensor <4 x16 x64 x64 xbf16 > {
161
+ %cst = arith.constant dense <1.000000e+00 > : tensor <16 x32 x64 x2 xbf16 >
162
+ %cst_0 = arith.constant 0.000000e+00 : bf16
163
+ %0 = bufferization.alloc_tensor () : tensor <4 x16 x64 x64 xbf16 >
164
+ %expanded = tensor.expand_shape %arg0 [[0 ], [1 ], [2 ], [3 , 4 ]] output_shape [4 , 16 , 64 , 32 , 2 ] : tensor <4 x16 x64 x64 xbf16 > into tensor <4 x16 x64 x32 x2 xbf16 >
165
+ %1 = scf.forall (%arg1 , %arg2 ) in (4 , 16 ) shared_outs (%arg3 = %0 ) -> (tensor <4 x16 x64 x64 xbf16 >) {
166
+ %extracted_slice = tensor.extract_slice %arg3 [%arg1 , %arg2 , 0 , 0 ] [1 , 1 , 64 , 64 ] [1 , 1 , 1 , 1 ] : tensor <4 x16 x64 x64 xbf16 > to tensor <64 x64 xbf16 >
167
+ %2 = linalg.fill ins (%cst_0 : bf16 ) outs (%extracted_slice : tensor <64 x64 xbf16 >) -> tensor <64 x64 xbf16 >
168
+ %extracted_slice_1 = tensor.extract_slice %expanded [%arg1 , 0 , 0 , 0 , 0 ] [1 , 16 , 64 , 32 , 2 ] [1 , 1 , 1 , 1 , 1 ] : tensor <4 x16 x64 x32 x2 xbf16 > to tensor <16 x64 x32 x2 xbf16 >
169
+ %3 = linalg.generic {indexing_maps = [#map , #map1 , #map2 ], iterator_types = ["reduction" , "reduction" , "parallel" , "parallel" , "reduction" ]} ins (%extracted_slice_1 , %cst : tensor <16 x64 x32 x2 xbf16 >, tensor <16 x32 x64 x2 xbf16 >) outs (%2 : tensor <64 x64 xbf16 >) {
170
+ ^bb0 (%in: bf16 , %in_2: bf16 , %out: bf16 ):
171
+ %4 = arith.mulf %in , %in_2 : bf16
172
+ %5 = arith.addf %out , %4 : bf16
173
+ linalg.yield %5 : bf16
174
+ } -> tensor <64 x64 xbf16 >
175
+ scf.forall.in_parallel {
176
+ tensor.parallel_insert_slice %3 into %arg3 [%arg1 , %arg2 , 0 , 0 ] [1 , 1 , 64 , 64 ] [1 , 1 , 1 , 1 ] : tensor <64 x64 xbf16 > into tensor <4 x16 x64 x64 xbf16 >
177
+ }
178
+ }
179
+ return %1 : tensor <4 x16 x64 x64 xbf16 >
131
180
}
132
181
}
133
182
134
-
135
- // CONF1-LABEL: func.func @brgemm_tensor_type_no_tiling
136
- // CONF1-NOT: scf.for
137
- // CONF2-NOT: scf.for
183
+ // CONF2-LABEL: func.func @gemm_64tiles_do_tiling_bf16_tensor
184
+ // CONF2-DAG: %[[C1:.+]] = arith.constant 1 : index
185
+ // CONF2-DAG: %[[C32:.+]] = arith.constant 32 : index
186
+ // CONF2-DAG: %[[C64:.+]] = arith.constant 64 : index
187
+ // CONF2-DAG: %[[C16:.+]] = arith.constant 16 : index
188
+ // CONF2-DAG: %[[C0:.+]] = arith.constant 0 : index
189
+ // CONF2: %3 = scf.for %[[I:.+]] = %[[C0]] to %[[C64]] step %[[C32]] iter_args(%arg5 = %2) -> (tensor<64x64xbf16>)
190
+ // CONF2-NEXT: %4 = scf.for %[[J:.+]] = %[[C0]] to %[[C64]] step %[[C32]] iter_args(%arg7 = %arg5) -> (tensor<64x64xbf16>)
191
+ // CONF2-NEXT: %5 = scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%arg9 = %arg7) -> (tensor<64x64xbf16>)
192
+ // CONF2-NEXT: %6 = scf.for %[[L:.+]] = %[[C0]] to %[[C32]] step %[[C16]] iter_args(%arg11 = %arg9) -> (tensor<64x64xbf16>)
193
+ // CONF2-NEXT: %extracted_slice_2 = tensor.extract_slice %extracted_slice_1[%[[K]], %[[I]], %[[L]], 0] [1, 32, 16, 2] [1, 1, 1, 1] : tensor<16x64x32x2xbf16> to tensor<1x32x16x2xbf16>
194
+ // CONF2-NEXT: %extracted_slice_3 = tensor.extract_slice %cst[%[[K]], %[[L]], %[[J]], 0] [1, 16, 32, 2] [1, 1, 1, 1] : tensor<16x32x64x2xbf16> to tensor<1x16x32x2xbf16>
195
+ // CONF2-NEXT: %extracted_slice_4 = tensor.extract_slice %arg11[%[[I]], %[[J]]] [32, 32] [1, 1] : tensor<64x64xbf16> to tensor<32x32xbf16>
196
+ // CONF2-NEXT: %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_2, %extracted_slice_3 : tensor<1x32x16x2xbf16>, tensor<1x16x32x2xbf16>) outs(%extracted_slice_4 : tensor<32x32xbf16>)
138
197
139
198
// -----
140
199
@@ -146,7 +205,44 @@ module {
146
205
}
147
206
}
148
207
149
-
150
208
// CONF1-LABEL: func.func @matmul_no_tiling
151
209
// CONF1-NOT: scf.for
210
+ // CONF2-LABEL: func.func @matmul_no_tiling
152
211
// CONF2-NOT: scf.for
212
+
213
+ // -----
214
+
215
+ func.func @batch_matmul_no_tiling (%arg0: tensor <512 x32 x64 xf32 >, %arg1: tensor <512 x64 x32 xf32 >) -> tensor <512 x32 x32 xf32 > {
216
+ %0 = tensor.empty () : tensor <512 x32 x32 xf32 >
217
+ %1 = linalg.batch_matmul ins (%arg0 , %arg1 : tensor <512 x32 x64 xf32 >, tensor <512 x64 x32 xf32 >)
218
+ outs (%0 : tensor <512 x32 x32 xf32 >) -> tensor <512 x32 x32 xf32 >
219
+ return %1 : tensor <512 x32 x32 xf32 >
220
+ }
221
+
222
+ // CONF1-LABEL: func.func @batch_matmul_no_tiling
223
+ // CONF1-NOT: scf.for
224
+ // CONF2-LABEL: func.func @batch_matmul_no_tiling
225
+ // CONF2-NOT: scf.for
226
+
227
+ // -----
228
+
229
+ #map = affine_map <(d0 , d1 ) -> (d0 , d1 )>
230
+ func.func @generic_matmul_no_tiling (%arg0: tensor <128 x128 xf32 >, %arg1: tensor <128 x128 xf32 >, %arg2: tensor <128 x128 xf32 >) -> tensor <128 x128 xf32 > {
231
+ %0 = linalg.matmul ins (%arg0 , %arg1: tensor <128 x128 xf32 >, tensor <128 x128 xf32 >)
232
+ outs (%arg2: tensor <128 x128 xf32 >)
233
+ -> tensor <128 x128 xf32 >
234
+ %c0 = arith.constant 0.0 : f32
235
+ %1 = linalg.generic {indexing_maps = [#map ], iterator_types = ["parallel" , "parallel" ]} outs (%0: tensor <128 x128 xf32 >) {
236
+ ^bb0 (%out: f32 ):
237
+ %2 = arith.maximumf %out , %c0 : f32
238
+ linalg.yield %2 : f32
239
+ } -> tensor <128 x128 xf32 >
240
+ return %1 : tensor <128 x128 xf32 >
241
+ }
242
+
243
+ // CONF1-LABEL: func.func @generic_matmul_no_tiling
244
+ // CONF1-NOT: scf.for
245
+ // CONF2-LABEL: func.func @generic_matmul_no_tiling
246
+ // CONF2-NOT: scf.for
247
+
248
+ // -----
0 commit comments