Commit 0d504c7

Brgemm register tiling support for tensor type (#1016)

This patch extends PR #1005 by adding register tiling support for the `tensor` type. It also adds new unit test cases.

1 parent f8d8a16 commit 0d504c7
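In effect, a brgemm on tensors is now rewritten into an scf.for nest that threads the accumulator tensor through iter_args, computes a small register-tile brgemm on extracted slices, and writes each tile back with tensor.insert_slice; previously the pattern bailed out on anything without pure buffer semantics. A condensed sketch, distilled from the new CONF1 test below (SSA names and the 8x32x1 tile sizes are illustrative, taken from that test's configuration):

  // Input: a tensor-semantics batch-reduce GEMM.
  %0 = linalg.batch_reduce_matmul
         ins(%A, %B : tensor<128x256x512xf32>, tensor<128x512x256xf32>)
         outs(%C : tensor<256x256xf32>) -> tensor<256x256xf32>

  // After tiling: loops over M (step 8), N (step 32), batch and K (step 1),
  // each iteration updating an 8x32 tile of the accumulator via
  // tensor.extract_slice / tensor.insert_slice and yielding the updated
  // tensor to the next iteration.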

2 files changed: +105 −11 lines changed

lib/TPP/Transforms/BrgemmLinalgTiling.cpp (1 addition, 3 deletions)

@@ -48,8 +48,6 @@ struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
 
   LogicalResult matchAndRewrite(BrgemmOp brgemmOp,
                                 PatternRewriter &rewriter) const override {
-    if (!brgemmOp.hasPureBufferSemantics())
-      return failure();
 
     // Check whether the tile sizes are valid
     if (options.registerTileShape.size() != 3)
@@ -177,7 +175,7 @@ struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
     if (failed(tiledOp)) {
       return failure();
     }
-    rewriter.replaceOp(brgemmOp, tiledOp->op->getResults());
+    rewriter.replaceOp(brgemmOp, tiledOp->tensorResults);
 
     return success();
   }
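Why the replacement value changed: with pure buffer semantics the tiled op updates its output memref in place and the original op has no SSA results to forward, while with tensor semantics the generated scf.for nest yields the updated accumulator tensor. Upstream MLIR's tiling helper returns those loop results separately from the tiled op itself. A minimal sketch, assuming the linalg::TiledLinalgOp struct as it exists in the MLIR these transforms build against (paraphrased; field types may differ across MLIR versions):

  struct TiledLinalgOp {
    LinalgOp op;                          // the tiled op inside the loop nest
    SmallVector<Operation *, 8> loops;    // the generated loops
    SmallVector<Value, 4> tensorResults;  // values yielded by the outermost
                                          // loop; empty for buffer semantics
  };

Replacing brgemmOp with tensorResults covers both cases: for tensor semantics it holds exactly the values that must stand in for the op's results, and for buffer semantics it is empty, matching the zero results of a memref-semantics brgemm.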

test/Passes/pass-tile-brgemm-linalg-matmul.mlir (104 additions, 8 deletions)

@@ -35,6 +35,34 @@ module {
 
 // -----
 
+module {
+  func.func @brgemm_tensor_type_tiling(%arg0: tensor<128x256x512xf32>, %arg1: tensor<128x512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
+    %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<128x256x512xf32>, tensor<128x512x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    return %0 : tensor<256x256xf32>
+  }
+}
+
+
+// CONF1-LABEL: func.func @brgemm_tensor_type_tiling
+// CONF1-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CONF1-DAG: %[[C256:.+]] = arith.constant 256 : index
+// CONF1-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CONF1-DAG: %[[C32:.+]] = arith.constant 32 : index
+// CONF1-DAG: %[[C128:.+]] = arith.constant 128 : index
+// CONF1-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CONF1-DAG: %[[C512:.+]] = arith.constant 512 : index
+// CONF1: %0 = scf.for %[[I:.+]] = %[[C0]] to %[[C256]] step %[[C8]] iter_args(%arg4 = %arg2) -> (tensor<256x256xf32>) {
+// CONF1-NEXT: %1 = scf.for %[[J:.+]] = %[[C0]] to %[[C256]] step %[[C32]] iter_args(%arg6 = %arg4) -> (tensor<256x256xf32>) {
+// CONF1-NEXT: %2 = scf.for %[[K:.+]] = %[[C0]] to %[[C128]] step %[[C1]] iter_args(%arg8 = %arg6) -> (tensor<256x256xf32>) {
+// CONF1-NEXT: %3 = scf.for %[[L:.+]] = %[[C0]] to %[[C512]] step %[[C1]] iter_args(%arg10 = %arg8) -> (tensor<256x256xf32>) {
+// CONF1-NEXT: %extracted_slice = tensor.extract_slice %arg0[%[[K]], %[[I]], %[[L]]] [1, 8, 1] [1, 1, 1] : tensor<128x256x512xf32> to tensor<1x8x1xf32>
+// CONF1-NEXT: %extracted_slice_0 = tensor.extract_slice %arg1[%[[K]], %[[L]], %[[J]]] [1, 1, 32] [1, 1, 1] : tensor<128x512x256xf32> to tensor<1x1x32xf32>
+// CONF1-NEXT: %extracted_slice_1 = tensor.extract_slice %arg10[%[[I]], %[[J]]] [8, 32] [1, 1] : tensor<256x256xf32> to tensor<8x32xf32>
+// CONF1-NEXT: %4 = linalg.batch_reduce_matmul ins(%extracted_slice, %extracted_slice_0 : tensor<1x8x1xf32>, tensor<1x1x32xf32>) outs(%extracted_slice_1 : tensor<8x32xf32>) -> tensor<8x32xf32>
+// CONF1-NEXT: %inserted_slice = tensor.insert_slice %4 into %arg10[%[[I]], %[[J]]] [8, 32] [1, 1] : tensor<8x32xf32> into tensor<256x256xf32>
+
+// -----
+
 #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
 #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
 #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
@@ -124,17 +152,48 @@ module {
 
 // -----
 
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
 module {
-  func.func @brgemm_tensor_type_no_tiling(%arg0: tensor<128x256x512xf32>, %arg1: tensor<128x512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
-    %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<128x256x512xf32>, tensor<128x512x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
-    return %0 : tensor<256x256xf32>
+  func.func @gemm_64tiles_do_tiling_bf16_tensor(%arg0: tensor<4x16x64x64xbf16>) -> tensor<4x16x64x64xbf16> {
+    %cst = arith.constant dense<1.000000e+00> : tensor<16x32x64x2xbf16>
+    %cst_0 = arith.constant 0.000000e+00 : bf16
+    %0 = bufferization.alloc_tensor() : tensor<4x16x64x64xbf16>
+    %expanded = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [4, 16, 64, 32, 2] : tensor<4x16x64x64xbf16> into tensor<4x16x64x32x2xbf16>
+    %1 = scf.forall (%arg1, %arg2) in (4, 16) shared_outs(%arg3 = %0) -> (tensor<4x16x64x64xbf16>) {
+      %extracted_slice = tensor.extract_slice %arg3[%arg1, %arg2, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<4x16x64x64xbf16> to tensor<64x64xbf16>
+      %2 = linalg.fill ins(%cst_0 : bf16) outs(%extracted_slice : tensor<64x64xbf16>) -> tensor<64x64xbf16>
+      %extracted_slice_1 = tensor.extract_slice %expanded[%arg1, 0, 0, 0, 0] [1, 16, 64, 32, 2] [1, 1, 1, 1, 1] : tensor<4x16x64x32x2xbf16> to tensor<16x64x32x2xbf16>
+      %3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_1, %cst : tensor<16x64x32x2xbf16>, tensor<16x32x64x2xbf16>) outs(%2 : tensor<64x64xbf16>) {
+      ^bb0(%in: bf16, %in_2: bf16, %out: bf16):
+        %4 = arith.mulf %in, %in_2 : bf16
+        %5 = arith.addf %out, %4 : bf16
+        linalg.yield %5 : bf16
+      } -> tensor<64x64xbf16>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg3[%arg1, %arg2, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<64x64xbf16> into tensor<4x16x64x64xbf16>
+      }
+    }
+    return %1 : tensor<4x16x64x64xbf16>
   }
 }
 
-
-// CONF1-LABEL: func.func @brgemm_tensor_type_no_tiling
-// CONF1-NOT: scf.for
-// CONF2-NOT: scf.for
+// CONF2-LABEL: func.func @gemm_64tiles_do_tiling_bf16_tensor
+// CONF2-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CONF2-DAG: %[[C32:.+]] = arith.constant 32 : index
+// CONF2-DAG: %[[C64:.+]] = arith.constant 64 : index
+// CONF2-DAG: %[[C16:.+]] = arith.constant 16 : index
+// CONF2-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CONF2: %3 = scf.for %[[I:.+]] = %[[C0]] to %[[C64]] step %[[C32]] iter_args(%arg5 = %2) -> (tensor<64x64xbf16>)
+// CONF2-NEXT: %4 = scf.for %[[J:.+]] = %[[C0]] to %[[C64]] step %[[C32]] iter_args(%arg7 = %arg5) -> (tensor<64x64xbf16>)
+// CONF2-NEXT: %5 = scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%arg9 = %arg7) -> (tensor<64x64xbf16>)
+// CONF2-NEXT: %6 = scf.for %[[L:.+]] = %[[C0]] to %[[C32]] step %[[C16]] iter_args(%arg11 = %arg9) -> (tensor<64x64xbf16>)
+// CONF2-NEXT: %extracted_slice_2 = tensor.extract_slice %extracted_slice_1[%[[K]], %[[I]], %[[L]], 0] [1, 32, 16, 2] [1, 1, 1, 1] : tensor<16x64x32x2xbf16> to tensor<1x32x16x2xbf16>
+// CONF2-NEXT: %extracted_slice_3 = tensor.extract_slice %cst[%[[K]], %[[L]], %[[J]], 0] [1, 16, 32, 2] [1, 1, 1, 1] : tensor<16x32x64x2xbf16> to tensor<1x16x32x2xbf16>
+// CONF2-NEXT: %extracted_slice_4 = tensor.extract_slice %arg11[%[[I]], %[[J]]] [32, 32] [1, 1] : tensor<64x64xbf16> to tensor<32x32xbf16>
+// CONF2-NEXT: %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice_2, %extracted_slice_3 : tensor<1x32x16x2xbf16>, tensor<1x16x32x2xbf16>) outs(%extracted_slice_4 : tensor<32x32xbf16>)
 
 // -----
 
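A note for decoding the CONF2 expectations above: the trailing size-2 dimension of the bf16 operands is the VNNI-style pairing of bf16 elements created by tensor.expand_shape, and the five iterators (d0..d4) of the linalg.generic spell out a batch-reduce contraction. Reading the three indexing maps:

  C[d2, d3] += sum over d0, d4, d1 of A[d0, d2, d4, d1] * B[d0, d4, d3, d1]

with d0 the batch dimension (16), d4 the reduction (32), and d1 the bf16 pair (2). The checks accordingly describe register tiling that blocks the parallel dimensions d2 and d3 by 32, steps the batch by 1 and the reduction by 16, and leaves the pair dimension untouched.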
@@ -146,7 +205,44 @@ module {
   }
 }
 
-
 // CONF1-LABEL: func.func @matmul_no_tiling
 // CONF1-NOT: scf.for
+// CONF2-LABEL: func.func @matmul_no_tiling
 // CONF2-NOT: scf.for
+
+// -----
+
+func.func @batch_matmul_no_tiling(%arg0: tensor<512x32x64xf32>, %arg1: tensor<512x64x32xf32>) -> tensor<512x32x32xf32> {
+  %0 = tensor.empty() : tensor<512x32x32xf32>
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x32x64xf32>, tensor<512x64x32xf32>)
+                           outs(%0 : tensor<512x32x32xf32>) -> tensor<512x32x32xf32>
+  return %1 : tensor<512x32x32xf32>
+}
+
+// CONF1-LABEL: func.func @batch_matmul_no_tiling
+// CONF1-NOT: scf.for
+// CONF2-LABEL: func.func @batch_matmul_no_tiling
+// CONF2-NOT: scf.for
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @generic_matmul_no_tiling(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+                     -> tensor<128x128xf32>
+  %c0 = arith.constant 0.0 : f32
+  %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%0: tensor<128x128xf32>) {
+  ^bb0(%out: f32):
+    %2 = arith.maximumf %out, %c0 : f32
+    linalg.yield %2 : f32
+  } -> tensor<128x128xf32>
+  return %1 : tensor<128x128xf32>
+}
+
+// CONF1-LABEL: func.func @generic_matmul_no_tiling
+// CONF1-NOT: scf.for
+// CONF2-LABEL: func.func @generic_matmul_no_tiling
+// CONF2-NOT: scf.for
+
+// -----
