@@ -56,8 +56,9 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
-                                    ArrayRef<int64_t> tileShape) const {
-    return options.getUnrolledTypes(type, tileShape);
+                                    ArrayRef<int64_t> tileShape,
+                                    bool returnSingleType = false) const {
+    return options.getUnrolledTypes(type, tileShape, returnSingleType);
  }

  /// Emulate the unpack behavior using insert_strided_slice for VectorType
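Note on the new returnSingleType flag: judging from how the later hunks pass hasOffsets for it and then index only element 0 of the result, it presumably lets a caller ask for a single unrolled type (to be reused with per-tile offsets) instead of one type per tile. A minimal standalone sketch of that idea in plain C++, with hypothetical names rather than the real xegpu::UnrollOptions API:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for getUnrolledTypes: when the caller plans to reuse
// one descriptor and only vary per-tile offsets, it asks for a single type;
// otherwise it gets one type per unrolled tile.
std::vector<std::string> getUnrolledTypesSketch(int64_t numTiles,
                                                const std::string &tileType,
                                                bool returnSingleType = false) {
  if (returnSingleType)
    return {tileType};
  return std::vector<std::string>(numTiles, tileType);
}

int main() {
  // Four tiles, normal unrolling: four types. With returnSingleType: one.
  auto perTile = getUnrolledTypesSketch(4, "tdesc<16x16xf32>");
  auto single = getUnrolledTypesSketch(4, "tdesc<16x16xf32>",
                                       /*returnSingleType=*/true);
  std::cout << perTile.size() << " vs " << single.size() << "\n"; // 4 vs 1
}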
@@ -121,53 +122,79 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
  xegpu::UnrollOptions options;
};

+// Generic helper function for unrolling operations with offsets.
+//
+// Iterates over tile offsets within the tensor descriptor shape and calls
+// the provided createOp function for each computed offset. This is used by
+// operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they
+// have explicit offsets that need to be adjusted for each unrolled tile.
+SmallVector<Value> computeUnrolledOffsets(
+    SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
+    ArrayRef<int64_t> targetShape,
+    const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
+    Location loc, PatternRewriter &rewriter) {
+  int64_t rank = tdescTy.getRank();
+  ArrayRef<int64_t> shape = tdescTy.getShape();
+
+  auto addi = [&](OpFoldResult a, int64_t b) -> Value {
+    std::optional<int64_t> maybeInt = getConstantIntValue(a);
+    if (maybeInt) {
+      return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
+    } else {
+      auto aV = llvm::cast<Value>(a);
+      auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
+      return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
+    }
+  };
+
+  SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
+      llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
+  auto validIdxes =
+      llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
+
+  SmallVector<Value> newOps;
+  for (SmallVector<int64_t> offsets :
+       StaticTileOffsetRange(shape, targetShape)) {
+
+    for (auto [idx, oldOff, offset] :
+         llvm::zip(validIdxes, oldOffsets, offsets))
+      mixedOffsets[idx] = addi(oldOff, offset);
+
+    auto newOp = createOp(mixedOffsets);
+    newOps.push_back(newOp);
+  }
+  return newOps;
+}
+
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
-    int64_t rank = tdescTy.getRank();
-    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

-    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
-    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
-      std::optional<int64_t> maybeInt = getConstantIntValue(a);
-      if (maybeInt) {
-        return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
-      } else {
-        auto aV = llvm::cast<Value>(a);
-        auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
-        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
-      }
-    };
-
-    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
-
-    // For n-D memrefs where n > rank, we need to handle the last `rank`
-    // dimensions only, and keep the first `n-rank` dimensions as is.
-    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
-        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
-    auto validIdxes =
-        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
-
    SmallVector<Value> newOps;
-    for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(shape, *targetShape)) {
-
-      for (auto [idx, oldOff, offset] :
-           llvm::zip(validIdxes, oldOffsets, offsets))
-        mixedOffsets[idx] = addi(oldOff, offset);

+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+    bool hasOffsets = op.getMixedOffsets().size() != 0;
+    if (!hasOffsets) {
      auto newOp = xegpu::CreateNdDescOp::create(
-          rewriter, loc, newTdescTy, op.getSource(), mixedOffsets,
-          op.getMixedSizes(), op.getMixedStrides());
+          rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
+          op.getMixedStrides());
      newOps.push_back(newOp);
+    } else {
+      auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
+        return xegpu::CreateNdDescOp::create(
+            rewriter, loc, newTdescTy, op.getSource(), offsets,
+            op.getMixedSizes(), op.getMixedStrides());
+      };
+
+      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
+                                      *targetShape, createOp, loc, rewriter);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
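To make computeUnrolledOffsets easier to follow, here is a minimal standalone sketch of the same iteration in plain C++ (no MLIR types; forEachUnrolledOffset and its row-major tile walk are illustrative assumptions standing in for StaticTileOffsetRange): only the trailing rank entries of the mixed offsets are adjusted per tile, and a callback plays the role of createOp.

#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Sketch of the per-tile offset computation: enumerate tile origins over
// `shape` in steps of `tileShape` (row-major), add them to the trailing
// offsets, and hand the adjusted offsets to a callback (the createOp role).
void forEachUnrolledOffset(
    std::vector<int64_t> mixedOffsets,     // e.g. {batch, row, col}
    const std::vector<int64_t> &shape,     // descriptor shape
    const std::vector<int64_t> &tileShape, // unroll target shape
    const std::function<void(const std::vector<int64_t> &)> &createOp) {
  size_t rank = shape.size();
  size_t first = mixedOffsets.size() - rank; // leading dims stay untouched
  std::vector<int64_t> base(mixedOffsets.begin() + first, mixedOffsets.end());

  // Row-major walk over tile origins, like StaticTileOffsetRange.
  std::vector<int64_t> tile(rank, 0);
  bool done = false;
  while (!done) {
    for (size_t i = 0; i < rank; ++i)
      mixedOffsets[first + i] = base[i] + tile[i];
    createOp(mixedOffsets);

    // Advance to the next tile origin.
    done = true;
    for (size_t i = rank; i-- > 0;) {
      tile[i] += tileShape[i];
      if (tile[i] < shape[i]) {
        done = false;
        break;
      }
      tile[i] = 0;
    }
  }
}

int main() {
  // A 32x32 descriptor unrolled into 16x16 tiles with base offsets {8, 4}:
  // the callback sees {8,4}, {8,20}, {24,4}, {24,20}.
  forEachUnrolledOffset({8, 4}, {32, 32}, {16, 16},
                        [](const std::vector<int64_t> &offs) {
                          std::cout << offs[0] << "," << offs[1] << "\n";
                        });
}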
@@ -216,17 +243,30 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
-      return failure();
+    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
+
+    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+        tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);

-    SmallVector<Type> convertedTdescTypes =
-        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

-    for (auto t : convertedTdesc)
-      xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
-                                  op->getAttrs());
+    if (!hasOffsets) {
+      for (auto t : convertedTdesc)
+        xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
+                                    op->getAttrs());
+    } else {
+      auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
+        xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
+                                    op.getL1HintAttr(), op.getL2HintAttr(),
+                                    op.getL3HintAttr());
+        // Return a dummy Value to satisfy the function's signature.
+        return nullptr;
+      };
+
+      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
+                             createPrefetch, loc, rewriter);
+    }

    rewriter.eraseOp(op);
    return success();
@@ -247,22 +287,33 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
-      return failure();
+    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

-    SmallVector<Type> convertedTdescTypes =
-        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+        tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);
+
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
-
    SmallVector<Value> newOps;
-    for (auto t : convertedTdescs) {
-      auto newOp =
-          xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t, op->getAttrs());
-      newOps.push_back(newOp);
+
+    if (!hasOffsets) {
+      for (auto t : convertedTdescs) {
+        auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
+                                             op->getAttrs());
+        newOps.push_back(newOp);
+      }
+    } else {
+      auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
+        return xegpu::LoadNdOp::create(
+            rewriter, loc, newValueTy, convertedTdescs[0], offsets,
+            op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
+            op.getL2HintAttr(), op.getL3HintAttr());
+      };
+      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
+                                      *targetShape, createLoad, loc, rewriter);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
@@ -285,22 +336,36 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
      return failure();

    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
-      return failure();
+    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
-    SmallVector<Type> convertedTdescTypes =
-        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+        tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);

-    SmallVector<Value> convertedValues =
-        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

-    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
-      xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
-                               op.getL2HintAttr(), op.getL3HintAttr());
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    if (!hasOffsets) {
+      for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
+        xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
+                                 op.getL2HintAttr(), op.getL3HintAttr());
+    } else {
+      size_t valueIndex = 0;
+      auto createStore = [&](SmallVector<OpFoldResult> offsets) {
+        xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
+                                 convertedTdescs[0], offsets,
+                                 op.getL1HintAttr(), op.getL2HintAttr(),
+                                 op.getL3HintAttr());
+        // Return a dummy Value to satisfy the function's signature.
+        return nullptr;
+      };
+
+      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
+                             createStore, loc, rewriter);
+    }

    rewriter.eraseOp(op);
    return success();
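As a small follow-up illustration of the store path above (plain C++ with hypothetical names, not the real StoreNd builder): when offsets are present, a single descriptor is reused while the unrolled values are consumed one per tile in the same order the tile offsets are generated, which is why createStore keeps a running valueIndex.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Sketch of StoreNd unrolling with offsets: one shared descriptor, one
// unrolled value per tile, offsets advanced tile by tile.
int main() {
  std::vector<std::string> unrolledValues = {"v0", "v1", "v2", "v3"};
  const std::string sharedDesc = "tdesc0"; // role of convertedTdescs[0]

  int64_t shape[2] = {32, 32}, tile[2] = {16, 16};
  size_t valueIndex = 0; // same idea as valueIndex++ above
  for (int64_t i = 0; i < shape[0]; i += tile[0])
    for (int64_t j = 0; j < shape[1]; j += tile[1])
      std::cout << "store " << unrolledValues[valueIndex++] << " into "
                << sharedDesc << " at offset (" << i << ", " << j << ")\n";
}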