25
25
#include " mlir/IR/BuiltinAttributes.h"
26
26
#include " mlir/IR/IntegerSet.h"
27
27
#include " mlir/IR/Visitors.h"
28
- #include " mlir/Transforms/DialectConversion .h"
28
+ #include " mlir/Transforms/WalkPatternRewriteDriver .h"
29
29
#include " llvm/ADT/DenseMap.h"
30
30
#include " llvm/Support/Debug.h"
31
31
#include < optional>
@@ -451,10 +451,10 @@ static void rewriteStore(fir::StoreOp storeOp,
451
451
}
452
452
453
453
static void rewriteMemoryOps (Block *block, mlir::PatternRewriter &rewriter) {
454
- for (auto &bodyOp : block->getOperations ()) {
454
+ for (auto &bodyOp : llvm::make_early_inc_range ( block->getOperations () )) {
455
455
if (isa<fir::LoadOp>(bodyOp))
456
456
rewriteLoad (cast<fir::LoadOp>(bodyOp), rewriter);
457
- if (isa<fir::StoreOp>(bodyOp))
457
+ else if (isa<fir::StoreOp>(bodyOp))
458
458
rewriteStore (cast<fir::StoreOp>(bodyOp), rewriter);
459
459
}
460
460
}
@@ -476,6 +476,8 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
476
476
loop.dump (););
477
477
LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
478
478
functionAnalysis.getChildLoopAnalysis (loop);
479
+ if (!loopAnalysis.canPromoteToAffine ())
480
+ return rewriter.notifyMatchFailure (loop, " cannot promote to affine" );
479
481
auto &loopOps = loop.getBody ()->getOperations ();
480
482
auto resultOp = cast<fir::ResultOp>(loop.getBody ()->getTerminator ());
481
483
auto results = resultOp.getOperands ();
@@ -576,12 +578,14 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
576
578
public:
577
579
using OpRewritePattern::OpRewritePattern;
578
580
AffineIfConversion (mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
579
- : OpRewritePattern(context) {}
581
+ : OpRewritePattern(context), functionAnalysis(afa) {}
580
582
llvm::LogicalResult
581
583
matchAndRewrite (fir::IfOp op,
582
584
mlir::PatternRewriter &rewriter) const override {
583
585
LLVM_DEBUG (llvm::dbgs () << " AffineIfConversion: rewriting if:\n " ;
584
586
op.dump (););
587
+ if (!functionAnalysis.getChildIfAnalysis (op).canPromoteToAffine ())
588
+ return rewriter.notifyMatchFailure (op, " cannot promote to affine" );
585
589
auto &ifOps = op.getThenRegion ().front ().getOperations ();
586
590
auto affineCondition = AffineIfCondition (op.getCondition ());
587
591
if (!affineCondition.hasIntegerSet ()) {
@@ -611,6 +615,8 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
611
615
rewriter.replaceOp (op, affineIf.getOperation ()->getResults ());
612
616
return success ();
613
617
}
618
+
619
+ AffineFunctionAnalysis &functionAnalysis;
614
620
};
615
621
616
622
// / Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases
@@ -627,28 +633,11 @@ class AffineDialectPromotion
627
633
mlir::RewritePatternSet patterns (context);
628
634
patterns.insert <AffineIfConversion>(context, functionAnalysis);
629
635
patterns.insert <AffineLoopConversion>(context, functionAnalysis);
630
- mlir::ConversionTarget target = *context;
631
- target.addLegalDialect <mlir::affine::AffineDialect, FIROpsDialect,
632
- mlir::scf::SCFDialect, mlir::arith::ArithDialect,
633
- mlir::func::FuncDialect>();
634
- target.addDynamicallyLegalOp <IfOp>([&functionAnalysis](fir::IfOp op) {
635
- return !(functionAnalysis.getChildIfAnalysis (op).canPromoteToAffine ());
636
- });
637
- target.addDynamicallyLegalOp <DoLoopOp>([&functionAnalysis](
638
- fir::DoLoopOp op) {
639
- return !(functionAnalysis.getChildLoopAnalysis (op).canPromoteToAffine ());
640
- });
641
-
642
636
LLVM_DEBUG (llvm::dbgs ()
643
637
<< " AffineDialectPromotion: running promotion on: \n " ;
644
638
function.print (llvm::dbgs ()););
645
639
// apply the patterns
646
- if (mlir::failed (mlir::applyPartialConversion (function, target,
647
- std::move (patterns)))) {
648
- mlir::emitError (mlir::UnknownLoc::get (context),
649
- " error in converting to affine dialect\n " );
650
- signalPassFailure ();
651
- }
640
+ walkAndApplyPatterns (function, std::move (patterns));
652
641
}
653
642
};
654
643
} // namespace
0 commit comments