Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 69 additions & 66 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
scf::ExecuteRegionOp executeRegionOp =
b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
scf::ExecuteRegionOp::create(b, op->getLoc(), op->getResultTypes());
{
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
Expand All @@ -169,7 +169,7 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
assert(clonedRegion.empty() && "expected empty region");
b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
clonedRegion.end());
b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
scf::YieldOp::create(b, op->getLoc(), clonedOp->getResults());
}
b.replaceOp(op, executeRegionOp.getResults());
return executeRegionOp;
Expand Down
36 changes: 18 additions & 18 deletions mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
// iter_arg's layout map must be changed (see uses of `castBuffer`).
assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
"scf.while op bufferization: cast incompatible");
return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
return memref::CastOp::create(b, buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return "true" if the given value
Expand Down Expand Up @@ -189,7 +189,7 @@ struct ExecuteRegionOpInterface

// Create new op and move over region.
auto newOp =
rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
scf::ExecuteRegionOp::create(rewriter, op->getLoc(), newResultTypes);
newOp.getRegion().takeBody(executeRegionOp.getRegion());

// Bufferize every block.
Expand All @@ -203,8 +203,8 @@ struct ExecuteRegionOpInterface
SmallVector<Value> newResults;
for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
if (isa<TensorType>(it.value())) {
newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
executeRegionOp.getLoc(), it.value(),
newResults.push_back(bufferization::ToTensorOp::create(
rewriter, executeRegionOp.getLoc(), it.value(),
newOp->getResult(it.index())));
} else {
newResults.push_back(newOp->getResult(it.index()));
Expand Down Expand Up @@ -258,9 +258,9 @@ struct IfOpInterface

// Create new op.
rewriter.setInsertionPoint(ifOp);
auto newIfOp =
rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
/*withElseRegion=*/true);
auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(), newTypes,
ifOp.getCondition(),
/*withElseRegion=*/true);

// Move over then/else blocks.
rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
Expand Down Expand Up @@ -372,9 +372,9 @@ struct IndexSwitchOpInterface

// Create new op.
rewriter.setInsertionPoint(switchOp);
auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
switchOp.getCases().size());
auto newSwitchOp = scf::IndexSwitchOp::create(
rewriter, switchOp.getLoc(), newTypes, switchOp.getArg(),
switchOp.getCases(), switchOp.getCases().size());

// Move over blocks.
for (auto [src, dest] :
Expand Down Expand Up @@ -767,8 +767,8 @@ struct ForOpInterface
}

// Construct a new scf.for op with memref instead of tensor values.
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
auto newForOp = scf::ForOp::create(
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), castedInitArgs);
newForOp->setAttrs(forOp->getAttrs());
Block *loopBody = newForOp.getBody();
Expand Down Expand Up @@ -1003,8 +1003,8 @@ struct WhileOpInterface
// Construct a new scf.while op with memref instead of tensor values.
ValueRange argsRangeBefore(castedInitArgs);
TypeRange argsTypesBefore(argsRangeBefore);
auto newWhileOp = rewriter.create<scf::WhileOp>(
whileOp.getLoc(), argsTypesAfter, castedInitArgs);
auto newWhileOp = scf::WhileOp::create(rewriter, whileOp.getLoc(),
argsTypesAfter, castedInitArgs);

// Add before/after regions to the new op.
SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
Expand Down Expand Up @@ -1263,17 +1263,17 @@ struct ForallOpInterface
forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
BlockArgument bbArg = std::get<0>(it);
Value buffer = std::get<1>(it);
Value bufferAsTensor = rewriter.create<ToTensorOp>(
forallOp.getLoc(), bbArg.getType(), buffer);
Value bufferAsTensor = ToTensorOp::create(rewriter, forallOp.getLoc(),
bbArg.getType(), buffer);
bbArg.replaceAllUsesWith(bufferAsTensor);
}

// Create new ForallOp without any results and drop the automatically
// introduced terminator.
rewriter.setInsertionPoint(forallOp);
ForallOp newForallOp;
newForallOp = rewriter.create<ForallOp>(
forallOp.getLoc(), forallOp.getMixedLowerBound(),
newForallOp = ForallOp::create(
rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
/*outputs=*/ValueRange(), forallOp.getMapping());

Expand Down
17 changes: 9 additions & 8 deletions mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,19 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
SmallVector<Value> initArgs;
initArgs.push_back(forOp.getLowerBound());
llvm::append_range(initArgs, forOp.getInitArgs());
auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
forOp->getAttrs());
auto whileOp = WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs,
forOp->getAttrs());

// 'before' region contains the loop condition and forwarding of iteration
// arguments to the 'after' region.
auto *beforeBlock = rewriter.createBlock(
&whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
auto cmpOp = rewriter.create<arith::CmpIOp>(
whileOp.getLoc(), arith::CmpIPredicate::slt,
auto cmpOp = arith::CmpIOp::create(
rewriter, whileOp.getLoc(), arith::CmpIPredicate::slt,
beforeBlock->getArgument(0), forOp.getUpperBound());
rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
beforeBlock->getArguments());
scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(),
beforeBlock->getArguments());

// Inline for-loop body into an executeRegion operation in the "after"
// region. The return type of the execRegionOp does not contain the
Expand All @@ -72,8 +72,9 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {

// Add induction variable incrementation
rewriter.setInsertionPointToEnd(afterBlock);
auto ivIncOp = rewriter.create<arith::AddIOp>(
whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep());
auto ivIncOp =
arith::AddIOp::create(rewriter, whileOp.getLoc(),
afterBlock->getArgument(0), forOp.getStep());

// Rewrite uses of the for-loop block arguments to the new while-loop
// "after" arguments
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
SmallVector<Value> steps = forallOp.getStep(rewriter);

// Create empty scf.parallel op.
auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps);
auto parallelOp = scf::ParallelOp::create(rewriter, loc, lbs, ubs, steps);
rewriter.eraseBlock(&parallelOp.getRegion().front());
rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
parallelOp.getRegion().begin());
Expand Down
125 changes: 64 additions & 61 deletions mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,25 +279,25 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
if (dynamicLoop) {
Type t = ub.getType();
// pred = ub > lb + (i * step)
Value iv = rewriter.create<arith::AddIOp>(
loc, lb,
rewriter.create<arith::MulIOp>(
loc, step,
rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(t, i))));
predicates[i] = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, iv, ub);
Value iv = arith::AddIOp::create(
rewriter, loc, lb,
arith::MulIOp::create(
rewriter, loc, step,
arith::ConstantOp::create(rewriter, loc,
rewriter.getIntegerAttr(t, i))));
predicates[i] = arith::CmpIOp::create(rewriter, loc,
arith::CmpIPredicate::slt, iv, ub);
}

// special handling for induction variable as the increment is implicit.
// iv = lb + i * step
Type t = lb.getType();
Value iv = rewriter.create<arith::AddIOp>(
loc, lb,
rewriter.create<arith::MulIOp>(
loc, step,
rewriter.create<arith::ConstantOp>(loc,
rewriter.getIntegerAttr(t, i))));
Value iv = arith::AddIOp::create(
rewriter, loc, lb,
arith::MulIOp::create(
rewriter, loc, step,
arith::ConstantOp::create(rewriter, loc,
rewriter.getIntegerAttr(t, i))));
setValueMapping(forOp.getInductionVar(), iv, i);
for (Operation *op : opOrder) {
if (stages[op] > i)
Expand Down Expand Up @@ -332,8 +332,8 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
Value prevValue = valueMapping
[forOp.getRegionIterArgs()[operand.getOperandNumber()]]
[i - stages[op]];
source = rewriter.create<arith::SelectOp>(
loc, predicates[predicateIdx], source, prevValue);
source = arith::SelectOp::create(
rewriter, loc, predicates[predicateIdx], source, prevValue);
}
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
source, i - stages[op] + 1);
Expand Down Expand Up @@ -444,15 +444,15 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
Type t = ub.getType();
Location loc = forOp.getLoc();
// newUb = ub - maxStage * step
Value maxStageValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(t, maxStage));
Value maxStageValue = arith::ConstantOp::create(
rewriter, loc, rewriter.getIntegerAttr(t, maxStage));
Value maxStageByStep =
rewriter.create<arith::MulIOp>(loc, step, maxStageValue);
newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep);
arith::MulIOp::create(rewriter, loc, step, maxStageValue);
newUb = arith::SubIOp::create(rewriter, loc, ub, maxStageByStep);
}
auto newForOp =
rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
forOp.getStep(), newLoopArg);
scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), newUb,
forOp.getStep(), newLoopArg);
// When there are no iter args, the loop body terminator will be created.
// Since we always create it below, remove the terminator if it was created.
if (!newForOp.getBody()->empty())
Expand Down Expand Up @@ -483,16 +483,17 @@ LogicalResult LoopPipelinerInternal::createKernel(
Type t = ub.getType();
for (unsigned i = 0; i < maxStage; i++) {
// c = ub - (maxStage - i) * step
Value c = rewriter.create<arith::SubIOp>(
loc, ub,
rewriter.create<arith::MulIOp>(
loc, step,
rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));

Value pred = rewriter.create<arith::CmpIOp>(
newForOp.getLoc(), arith::CmpIPredicate::slt,
newForOp.getInductionVar(), c);
Value c = arith::SubIOp::create(
rewriter, loc, ub,
arith::MulIOp::create(
rewriter, loc, step,
arith::ConstantOp::create(
rewriter, loc,
rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));

Value pred = arith::CmpIOp::create(rewriter, newForOp.getLoc(),
arith::CmpIPredicate::slt,
newForOp.getInductionVar(), c);
predicates[i] = pred;
}
}
Expand All @@ -515,13 +516,13 @@ LogicalResult LoopPipelinerInternal::createKernel(

// offset = (maxStage - stages[op]) * step
Type t = step.getType();
Value offset = rewriter.create<arith::MulIOp>(
forOp.getLoc(), step,
rewriter.create<arith::ConstantOp>(
forOp.getLoc(),
Value offset = arith::MulIOp::create(
rewriter, forOp.getLoc(), step,
arith::ConstantOp::create(
rewriter, forOp.getLoc(),
rewriter.getIntegerAttr(t, maxStage - stages[op])));
Value iv = rewriter.create<arith::AddIOp>(
forOp.getLoc(), newForOp.getInductionVar(), offset);
Value iv = arith::AddIOp::create(rewriter, forOp.getLoc(),
newForOp.getInductionVar(), offset);
nestedNewOp->setOperand(operand->getOperandNumber(), iv);
rewriter.setInsertionPointAfter(newOp);
continue;
Expand Down Expand Up @@ -594,8 +595,8 @@ LogicalResult LoopPipelinerInternal::createKernel(
auto defStage = stages.find(def);
if (defStage != stages.end() && defStage->second < maxStage) {
Value pred = predicates[defStage->second];
source = rewriter.create<arith::SelectOp>(
pred.getLoc(), pred, source,
source = arith::SelectOp::create(
rewriter, pred.getLoc(), pred, source,
newForOp.getBody()
->getArguments()[yieldOperand.getOperandNumber() + 1]);
}
Expand Down Expand Up @@ -638,7 +639,7 @@ LogicalResult LoopPipelinerInternal::createKernel(
maxStage - defStage->second + 1);
}
}
rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
scf::YieldOp::create(rewriter, forOp.getLoc(), yieldOperands);
return success();
}

Expand All @@ -652,51 +653,53 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
// removed by dead code if not used.

auto createConst = [&](int v) {
return rewriter.create<arith::ConstantOp>(loc,
rewriter.getIntegerAttr(t, v));
return arith::ConstantOp::create(rewriter, loc,
rewriter.getIntegerAttr(t, v));
};

// total_iterations = cdiv(range_diff, step);
// - range_diff = ub - lb
// - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
Value zero = createConst(0);
Value one = createConst(1);
Value stepLessZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, step, zero);
Value stepDecr =
rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1));
Value stepLessZero = arith::CmpIOp::create(
rewriter, loc, arith::CmpIPredicate::slt, step, zero);
Value stepDecr = arith::SelectOp::create(rewriter, loc, stepLessZero, one,
createConst(-1));

Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
Value rangeDiff = arith::SubIOp::create(rewriter, loc, ub, lb);
Value rangeIncrStep = arith::AddIOp::create(rewriter, loc, rangeDiff, step);
Value rangeDecr =
rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
arith::AddIOp::create(rewriter, loc, rangeIncrStep, stepDecr);
Value totalIterations =
arith::DivSIOp::create(rewriter, loc, rangeDecr, step);

// If total_iters < max_stage, start the epilogue at zero to match the
// ramp-up in the prologue.
// start_iter = max(0, total_iters - max_stage)
Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations,
createConst(maxStage));
iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
Value iterI = arith::SubIOp::create(rewriter, loc, totalIterations,
createConst(maxStage));
iterI = arith::MaxSIOp::create(rewriter, loc, zero, iterI);

// Capture predicates for dynamic loops.
SmallVector<Value> predicates(maxStage + 1);

for (int64_t i = 1; i <= maxStage; i++) {
// newLastIter = lb + step * iterI
Value newlastIter = rewriter.create<arith::AddIOp>(
loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
Value newlastIter = arith::AddIOp::create(
rewriter, loc, lb, arith::MulIOp::create(rewriter, loc, step, iterI));

setValueMapping(forOp.getInductionVar(), newlastIter, i);

// increment to next iterI
iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);
iterI = arith::AddIOp::create(rewriter, loc, iterI, one);

if (dynamicLoop) {
// Disable stages when `i` is greater than total_iters.
// pred = total_iters >= i
predicates[i] = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, totalIterations, createConst(i));
predicates[i] =
arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
totalIterations, createConst(i));
}
}

Expand Down Expand Up @@ -758,8 +761,8 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
unsigned nextVersion = currentVersion + 1;
Value pred = predicates[currentVersion];
Value prevValue = valueMapping[mapVal][currentVersion];
auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
prevValue);
auto selOp = arith::SelectOp::create(rewriter, loc, pred,
pair.value(), prevValue);
returnValues[ri] = selOp;
if (nextVersion <= maxStage)
setValueMapping(mapVal, selOp, nextVersion);
Expand Down
Loading