Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,18 @@ class DoConcurrentConversion
mlir::SymbolTable &moduleSymbolTable;
};

/// A listener that forwards notifyOperationErased to the given callback.
struct CallbackListener : public mlir::RewriterBase::Listener {
CallbackListener(std::function<void(mlir::Operation *op)> onOperationErased)
: onOperationErased(onOperationErased) {}

void notifyOperationErased(mlir::Operation *op) override {
onOperationErased(op);
}

std::function<void(mlir::Operation *op)> onOperationErased;
};

class DoConcurrentConversionPass
: public flangomp::impl::DoConcurrentConversionPassBase<
DoConcurrentConversionPass> {
Expand All @@ -468,6 +480,10 @@ class DoConcurrentConversionPass
}

llvm::DenseSet<fir::DoConcurrentOp> concurrentLoopsToSkip;
CallbackListener callbackListener([&](mlir::Operation *op) {
if (auto loop = mlir::dyn_cast<fir::DoConcurrentOp>(op))
concurrentLoopsToSkip.erase(loop);
});
mlir::RewritePatternSet patterns(context);
patterns.insert<DoConcurrentConversion>(
context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
Expand All @@ -480,8 +496,11 @@ class DoConcurrentConversionPass
target.markUnknownOpDynamicallyLegal(
[](mlir::Operation *) { return true; });

if (mlir::failed(
mlir::applyFullConversion(module, target, std::move(patterns)))) {
mlir::ConversionConfig config;
config.allowPatternRollback = false;
config.listener = &callbackListener;
if (mlir::failed(mlir::applyFullConversion(module, target,
std::move(patterns), config))) {
signalPassFailure();
}
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ class RewriterBase : public OpBuilder {

/// Find uses of `from` and replace them with `to`. Also notify the listener
/// about every in-place op modification (for every use that was replaced).
void replaceAllUsesWith(Value from, Value to) {
virtual void replaceAllUsesWith(Value from, Value to) {
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
Operation *op = operand.getOwner();
modifyOpInPlace(op, [&]() { operand.set(to); });
Expand Down
30 changes: 22 additions & 8 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -854,15 +854,29 @@ class ConversionPatternRewriter final : public PatternRewriter {
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion = nullptr);

/// Replace all the uses of the block argument `from` with `to`. This
/// function supports both 1:1 and 1:N replacements.
/// Replace all the uses of `from` with `to`. The type of `from` and `to` is
/// allowed to differ. The conversion driver will try to reconcile all type
/// mismatches that still exist at the end of the conversion with
/// materializations. This function supports both 1:1 and 1:N replacements.
///
/// Note: If `allowPatternRollback` is set to "true", this function replaces
/// all current and future uses of the block argument. This same block
/// block argument must not be replaced multiple times. Uses are not replaced
/// immediately but in a delayed fashion. Patterns may still see the original
/// uses when inspecting IR.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
/// Note: If `allowPatternRollback` is set to "true", this function behaves
/// slightly different:
///
/// 1. All current and future uses of `from` are replaced. The same value must
/// not be replaced multiple times. That's an API violation.
/// 2. Uses are not replaced immediately but in a delayed fashion. Patterns
/// may still see the original uses when inspecting IR.
/// 3. Uses within the same block that appear before the defining operation
/// of the replacement value are not replaced. This allows users to
/// perform certain replaceAllUsesExcept-style replacements, even though
/// such API is not directly supported.
///
/// Note: In an attempt to align the ConversionPatternRewriter and
/// RewriterBase APIs, (3) may be removed in the future.
void replaceAllUsesWith(Value from, ValueRange to);
void replaceAllUsesWith(Value from, Value to) override {
replaceAllUsesWith(from, ValueRange{to});
}

/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ static void restoreByValRefArgumentType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());

Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
rewriter.replaceUsesOfBlockArgument(arg, valueArg);
rewriter.replaceAllUsesWith(arg, valueArg);
}
}

Expand Down
Loading