
Commit 04258fe

[mlir][XeGPU][XeGPUUnroll] Support new syntax with offsets moved to load_nd/store_nd/prefetch_nd (#160323)
Adds support for the new syntax in XeGPUUnroll for:

1. `create_nd_desc` without offsets
2. `load_nd` with offsets
3. `store_nd` with offsets
4. `prefetch_nd` with offsets

The mixed form `create_nd_desc` with offsets + `load_nd` with offsets is not lowered correctly: in that case the IR still contains two unrealized conversion casts that fail later in the pipeline.

The offset computation for each unrolled tile is moved from the descriptors to the load/store/prefetch operations. The resulting IR has a single descriptor that the unrolled load/store/prefetch ops index with per-tile offsets.

<details><summary>old/new behavior examples</summary>

```mlir
// before unroll pass:
gpu.func @load_nd(%src: memref<256x318xf32>) -> vector<24x32xf32> {
  %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
  %ld = xegpu.load_nd %tdesc[8, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
  gpu.return %ld : vector<24x32xf32>
}

// after unroll pass (offsets in create_nd_desc):
gpu.func @create_nd_tdesc2(%arg0: memref<256x318xf32>) -> vector<24x32xf32> {
  %cst = arith.constant dense<0.000000e+00> : vector<24x32xf32>
  %c24 = arith.constant 24 : index
  %c32 = arith.constant 32 : index
  %c8 = arith.constant 8 : index
  %c16 = arith.constant 16 : index
  // create 6 descriptors, one for each tile
  %0 = xegpu.create_nd_tdesc %arg0[%c8, %c16] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  %1 = xegpu.create_nd_tdesc %arg0[%c8, %c32] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  %2 = xegpu.create_nd_tdesc %arg0[%c16, %c16] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  %3 = xegpu.create_nd_tdesc %arg0[%c16, %c32] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  %4 = xegpu.create_nd_tdesc %arg0[%c24, %c16] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  %5 = xegpu.create_nd_tdesc %arg0[%c24, %c32] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  %6 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %7 = xegpu.load_nd %1 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %8 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %9 = xegpu.load_nd %3 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %10 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %11 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  ...
}

// after unroll pass (offsets in load_nd):
gpu.func @load_nd(%arg0: memref<256x318xf32>) -> vector<24x32xf32> {
  %cst = arith.constant dense<0.000000e+00> : vector<24x32xf32>
  %c24 = arith.constant 24 : index
  %c32 = arith.constant 32 : index
  %c16 = arith.constant 16 : index
  %c8 = arith.constant 8 : index
  // create only one descriptor with the tile shape
  %0 = xegpu.create_nd_tdesc %arg0 : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  // compute tile offsets at the operation, reusing the single descriptor
  %1 = xegpu.load_nd %0[%c8, %c16] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %2 = xegpu.load_nd %0[%c8, %c32] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %3 = xegpu.load_nd %0[%c16, %c16] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %4 = xegpu.load_nd %0[%c16, %c32] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %5 = xegpu.load_nd %0[%c24, %c16] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %6 = xegpu.load_nd %0[%c24, %c32] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  ...
}
```

</details>

Signed-off-by: dchigarev <[email protected]>
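For reference, the unsupported mixed form called out above (offsets supplied to both `create_nd_tdesc` and `load_nd`) is the pattern exercised by the new `load_nd_offsets_at_both_places` test further down; a minimal sketch of that input, which after unrolling is expected to keep two `builtin.unrealized_conversion_cast` ops, looks like this:

```mlir
// Sketch of the unsupported mixed form (mirrors the
// load_nd_offsets_at_both_places test): offsets appear on both the
// descriptor creation and the load, so unrolling leaves two
// unrealized_conversion_cast ops behind.
gpu.func @load_nd_offsets_at_both_places(%src: memref<256x318xf32>) -> vector<24x32xf32> {
  %tdesc = xegpu.create_nd_tdesc %src[16, 8] : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
  %ld = xegpu.load_nd %tdesc[8, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
  gpu.return %ld : vector<24x32xf32>
}
```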
1 parent ca61a9d commit 04258fe

5 files changed: +198 -64 lines changed

mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h

Lines changed: 5 additions & 3 deletions
```diff
@@ -9,9 +9,9 @@
 #ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/IR/Operation.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/LogicalResult.h"
-#include "mlir/IR/Operation.h"
 
 #include <functional>
 #include <optional>
@@ -47,9 +47,11 @@ struct UnrollOptions {
 
   /// Function that converts a ShapedType (TensorDescType or VectorType)
   /// into the unrolled type based on the tileShape. It returns a vector of
-  /// types representing the unrolled types for simplicity.
+  /// types representing the unrolled types for simplicity. When
+  /// `returnSingleType` is true, it returns a vector containing only one single
+  /// unrolled type.
   using UnrolledTypeFnType = std::function<SmallVector<Type>(
-      ShapedType type, ArrayRef<int64_t> tileShape)>;
+      ShapedType type, ArrayRef<int64_t> tileShape, bool returnSingleType)>;
   UnrolledTypeFnType getUnrolledTypes = nullptr;
   UnrollOptions &setUnrolledTypesFn(UnrolledTypeFnType fn) {
     getUnrolledTypes = std::move(fn);
```

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 4 additions & 1 deletion
```diff
@@ -319,7 +319,8 @@ void XeGPUBlockingPass::runOnOperation() {
 
   options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
 
-  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
+  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape,
+                                 bool returnSingleType = false) {
     Type elemTy = type.getElementType();
     Type newTy;
 
@@ -352,6 +353,8 @@ void XeGPUBlockingPass::runOnOperation() {
      newTy = type.clone(tileShape, elemTy);
    }
 
+    if (returnSingleType)
+      return SmallVector<Type>{newTy};
    std::optional<SmallVector<int64_t>> ratio =
        computeShapeRatio(type.getShape(), tileShape);
    assert(ratio && "The shape of the type must be a multiple of tileShape.");
```

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 124 additions & 59 deletions
```diff
@@ -56,8 +56,9 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
   }
 
   SmallVector<Type> getUnrolledTypes(ShapedType type,
-                                     ArrayRef<int64_t> tileShape) const {
-    return options.getUnrolledTypes(type, tileShape);
+                                     ArrayRef<int64_t> tileShape,
+                                     bool returnSingleType = false) const {
+    return options.getUnrolledTypes(type, tileShape, returnSingleType);
   }
 
   /// Emulate the the unpack behavior using insert_strided_slice for VectorType
@@ -121,53 +122,79 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
   xegpu::UnrollOptions options;
 };
 
+// Generic helper function for unrolling operations with offsets.
+//
+// Iterates over tile offsets within the tensor descriptor shape and calls
+// the provided createOp function for each computed offset. This is used by
+// operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they
+// have explicit offsets that need to be adjusted for each unrolled tile.
+SmallVector<Value> computeUnrolledOffsets(
+    SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
+    ArrayRef<int64_t> targetShape,
+    const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
+    Location loc, PatternRewriter &rewriter) {
+  int64_t rank = tdescTy.getRank();
+  ArrayRef<int64_t> shape = tdescTy.getShape();
+
+  auto addi = [&](OpFoldResult a, int64_t b) -> Value {
+    std::optional<int64_t> maybeInt = getConstantIntValue(a);
+    if (maybeInt) {
+      return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
+    } else {
+      auto aV = llvm::cast<Value>(a);
+      auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
+      return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
+    }
+  };
+
+  SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
+      llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
+  auto validIdxes =
+      llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
+
+  SmallVector<Value> newOps;
+  for (SmallVector<int64_t> offsets :
+       StaticTileOffsetRange(shape, targetShape)) {
+
+    for (auto [idx, oldOff, offset] :
+         llvm::zip(validIdxes, oldOffsets, offsets))
+      mixedOffsets[idx] = addi(oldOff, offset);
+
+    auto newOp = createOp(mixedOffsets);
+    newOps.push_back(newOp);
+  }
+  return newOps;
+}
+
 struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
   using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
   LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
-    int64_t rank = tdescTy.getRank();
-    ArrayRef<int64_t> shape = tdescTy.getShape();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
 
-    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
-    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
-      std::optional<int64_t> maybeInt = getConstantIntValue(a);
-      if (maybeInt) {
-        return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
-      } else {
-        auto aV = llvm::cast<Value>(a);
-        auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
-        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
-      }
-    };
-
-    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
-
-    // For n-D memrefs where n > rank, we need to handle the last `rank`
-    // dimensions only, and keep the first `n-rank` dimensions as is.
-    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
-        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
-    auto validIdxes =
-        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
-
     SmallVector<Value> newOps;
-    for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(shape, *targetShape)) {
-
-      for (auto [idx, oldOff, offset] :
-           llvm::zip(validIdxes, oldOffsets, offsets))
-        mixedOffsets[idx] = addi(oldOff, offset);
 
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+    bool hasOffsets = op.getMixedOffsets().size() != 0;
+    if (!hasOffsets) {
      auto newOp = xegpu::CreateNdDescOp::create(
-          rewriter, loc, newTdescTy, op.getSource(), mixedOffsets,
-          op.getMixedSizes(), op.getMixedStrides());
+          rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
+          op.getMixedStrides());
      newOps.push_back(newOp);
+    } else {
+      auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
+        return xegpu::CreateNdDescOp::create(
+            rewriter, loc, newTdescTy, op.getSource(), offsets,
+            op.getMixedSizes(), op.getMixedStrides());
+      };
+
+      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
+                                      *targetShape, createOp, loc, rewriter);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
@@ -216,17 +243,30 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
      return failure();
 
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
-      return failure();
+    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
+
+    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+        tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);
 
-    SmallVector<Type> convertedTdescTypes =
-        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    for (auto t : convertedTdesc)
-      xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
-                                  op->getAttrs());
+    if (!hasOffsets) {
+      for (auto t : convertedTdesc)
+        xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
+                                    op->getAttrs());
+    } else {
+      auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
+        xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
+                                    op.getL1HintAttr(), op.getL2HintAttr(),
+                                    op.getL3HintAttr());
+        // return dummy Value to satisfy function's signature
+        return nullptr;
+      };
+
+      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
+                             createPrefetch, loc, rewriter);
+    }
 
    rewriter.eraseOp(op);
    return success();
@@ -247,22 +287,33 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
      return failure();
 
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
-      return failure();
+    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
 
    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
 
-    SmallVector<Type> convertedTdescTypes =
-        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+        tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);
+
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
-
    SmallVector<Value> newOps;
-    for (auto t : convertedTdescs) {
-      auto newOp =
-          xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t, op->getAttrs());
-      newOps.push_back(newOp);
+
+    if (!hasOffsets) {
+      for (auto t : convertedTdescs) {
+        auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
+                                             op->getAttrs());
+        newOps.push_back(newOp);
+      }
+    } else {
+      auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
+        return xegpu::LoadNdOp::create(
+            rewriter, loc, newValueTy, convertedTdescs[0], offsets,
+            op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
+            op.getL2HintAttr(), op.getL3HintAttr());
+      };
+      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
+                                      *targetShape, createLoad, loc, rewriter);
    }
 
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
@@ -285,22 +336,36 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
      return failure();
 
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
-      return failure();
+    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
 
    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
-    SmallVector<Type> convertedTdescTypes =
-        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+        tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);
 
-    SmallVector<Value> convertedValues =
-        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
-      xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
-                               op.getL2HintAttr(), op.getL3HintAttr());
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    if (!hasOffsets) {
+      for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
+        xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
+                                 op.getL2HintAttr(), op.getL3HintAttr());
+    } else {
+      size_t valueIndex = 0;
+      auto createStore = [&](SmallVector<OpFoldResult> offsets) {
+        xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
+                                 convertedTdescs[0], offsets,
+                                 op.getL1HintAttr(), op.getL2HintAttr(),
+                                 op.getL3HintAttr());
+        // return dummy Value to satisfy function's signature
+        return nullptr;
+      };
+
+      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
+                             createStore, loc, rewriter);
+    }
 
    rewriter.eraseOp(op);
    return success();
```
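The new `computeUnrolledOffsets` helper walks `StaticTileOffsetRange(shape, targetShape)` and adds each tile's static offset to the op's base offsets, so every unrolled op indexes the single tile-shaped descriptor. As a rough sketch (not verbatim pass output; SSA names are illustrative and the `vector.insert_strided_slice`/`vector.extract_strided_slice` pack/unpack between loads and stores is elided), the `load_nd_store_nd` test below is expected to unroll along these lines:

```mlir
// Sketch only: a 24x32 region blocked into 8x16 tiles gives six tile
// offsets per op. Loads use the base offset [8, 16], stores use [0, 0],
// and both reuse the same single 8x16 descriptor.
gpu.func @load_nd_store_nd(%arg0: memref<256x318xf32>) {
  %c0  = arith.constant 0 : index
  %c8  = arith.constant 8 : index
  %c16 = arith.constant 16 : index
  %c24 = arith.constant 24 : index
  %c32 = arith.constant 32 : index
  %0 = xegpu.create_nd_tdesc %arg0 : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  %1 = xegpu.load_nd %0[%c8, %c16]  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %2 = xegpu.load_nd %0[%c8, %c32]  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %3 = xegpu.load_nd %0[%c16, %c16] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %4 = xegpu.load_nd %0[%c16, %c32] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %5 = xegpu.load_nd %0[%c24, %c16] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %6 = xegpu.load_nd %0[%c24, %c32] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  // In the real output the loaded tiles are reassembled into a vector<24x32xf32>
  // and re-extracted before the stores; that round trip is omitted here.
  xegpu.store_nd %1, %0[%c0, %c0]   : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
  xegpu.store_nd %2, %0[%c0, %c16]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
  xegpu.store_nd %3, %0[%c8, %c0]   : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
  xegpu.store_nd %4, %0[%c8, %c16]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
  xegpu.store_nd %5, %0[%c16, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
  xegpu.store_nd %6, %0[%c16, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
  gpu.return
}
```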
Lines changed: 61 additions & 0 deletions
```mlir
// RUN: mlir-opt --test-xegpu-unrolling-patterns -split-input-file %s | FileCheck %s

gpu.module @xevm_test {

  // CHECK-LABEL: create_nd_tdesc
  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
  // CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
  // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
  // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>
  // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xegpu_blocking_tile_shape__ = array<i64: 8, 16>, __xegpu_blocking_unpack__}
  gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
    %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
  }

  //-----
  // CHECK-LABEL: load_nd
  // CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32>
  // CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
  gpu.func @load_nd(%src: memref<256x318xf32>) -> vector<24x32xf32> {
    %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
    gpu.return %ld : vector<24x32xf32>
  }

  //-----
  // CHECK-LABEL: load_nd_store_nd
  // CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32>
  // CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  // CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  // CHECK-COUNT-6: xegpu.store_nd {{.*}}[{{.*}}] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
  gpu.func @load_nd_store_nd(%src: memref<256x318xf32>) {
    %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
    xegpu.store_nd %ld, %tdesc[0, 0] : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    gpu.return
  }

  //-----
  // CHECK-LABEL: prefetch_nd_tdesc
  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
  // CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
  // CHECK-COUNT-6: xegpu.prefetch_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32>
  gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
    %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    xegpu.prefetch_nd %tdesc[8, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    gpu.return
  }

  //-----

  // CHECK-LABEL: load_nd_offsets_at_both_places
  // CHECK-COUNT-2: builtin.unrealized_conversion_cast
  gpu.func @load_nd_offsets_at_both_places(%src: memref<256x318xf32>) -> vector<24x32xf32> {
    %tdesc = xegpu.create_nd_tdesc %src[16, 8] : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
    %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
    gpu.return %ld : vector<24x32xf32>
  }
}
```
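The `prefetch_nd_tdesc` test above follows the same pattern as the load case; a rough sketch of its expected unrolled body (constants such as `%c8`/`%c16`/`%c24`/`%c32` are illustrative, as in the earlier examples):

```mlir
// Sketch only: one 8x16 descriptor, six prefetches at base offset [8, 16].
%0 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
xegpu.prefetch_nd %0[%c8, %c16]  : !xegpu.tensor_desc<8x16xf32>
xegpu.prefetch_nd %0[%c8, %c32]  : !xegpu.tensor_desc<8x16xf32>
xegpu.prefetch_nd %0[%c16, %c16] : !xegpu.tensor_desc<8x16xf32>
xegpu.prefetch_nd %0[%c16, %c32] : !xegpu.tensor_desc<8x16xf32>
xegpu.prefetch_nd %0[%c24, %c16] : !xegpu.tensor_desc<8x16xf32>
xegpu.prefetch_nd %0[%c24, %c32] : !xegpu.tensor_desc<8x16xf32>
```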

mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 4 additions & 1 deletion
```diff
@@ -113,7 +113,8 @@ struct TestXeGPUUnrollingPatterns
     });
 
     options.setUnrolledTypesFn(
-        [&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
+        [&](ShapedType type, ArrayRef<int64_t> tileShape,
+            bool returnSingleType = false) -> SmallVector<Type> {
          Type elemTy = type.getElementType();
          Type newTy;
 
@@ -155,6 +156,8 @@ struct TestXeGPUUnrollingPatterns
            newTy = type.clone(tileShape, elemTy);
          }
 
+          if (returnSingleType)
+            return SmallVector<Type>{newTy};
          std::optional<SmallVector<int64_t>> ratio =
              computeShapeRatio(type.getShape(), tileShape);
          assert(ratio && "Expecting the ratio to be valid.");
```
