21
21
#include " TPP/Dialect/Perf/PerfDialect.h"
22
22
#include " TPP/Dialect/Xsmm/XsmmDialect.h"
23
23
#include " TPP/PassUtils.h"
24
+ #include " mlir/Transforms/Passes.h"
24
25
25
26
using namespace mlir ;
26
27
using namespace mlir ::tpp;
@@ -97,8 +98,14 @@ struct DefaultTppPasses
97
98
// Bufferize: tensor->memref.
98
99
pm.addPass (createBufferize ());
99
100
100
- // Lower all Tile operations.
101
- pm.addNestedPass <func::FuncOp>(createLinalgLowering ());
101
+ if (linalgToVector) {
102
+ pm.addNestedPass <func::FuncOp>(createVectorizationPass ());
103
+ pm.addNestedPass <func::FuncOp>(createVectorContractPass ());
104
+ pm.addNestedPass <func::FuncOp>(createCanonicalizerPass ());
105
+ } else {
106
+ // Lower all Tile operations.
107
+ pm.addNestedPass <func::FuncOp>(createLinalgLowering ());
108
+ }
102
109
pm.addPass (createCleanup ());
103
110
}
104
111
@@ -109,10 +116,18 @@ struct DefaultTppPasses
109
116
// Low leve parallelization passes.
110
117
LowLevelParallelizationOptions LowLevelParallelization{parallelTaskGrid};
111
118
pm.addPass (createLowLevelParallelization (LowLevelParallelization));
112
-
113
- // Covert all local TPP-related dialects.
114
- pm.addPass (createLocalDialectsLowering ());
115
-
119
+ if (linalgToVector) {
120
+ pm.addPass (createConvertVectorToSCFPass ());
121
+ } else {
122
+ // Covert all local TPP-related dialects.
123
+ pm.addPass (createLocalDialectsLowering ());
124
+
125
+ pm.addNestedPass <func::FuncOp>(createIntelAMXTileConfigInsertionPass ());
126
+ pm.addNestedPass <func::FuncOp>(createCanonicalizerPass ());
127
+ pm.addNestedPass <func::FuncOp>(createLoopInvariantCodeMotionPass ());
128
+ pm.addNestedPass <func::FuncOp>(createCanonicalizerPass ());
129
+ pm.addNestedPass <func::FuncOp>(createIntelAMXTileConfigHoistingPass ());
130
+ }
116
131
// Clean up after the default pipeline.
117
132
pm.addNestedPass <func::FuncOp>(createPostprocessing ());
118
133
}
0 commit comments