diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp index bab9e2852a460..a3e1542e6a947 100644 --- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp @@ -36,7 +36,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value, if (auto boolValue = dyn_cast(value)) { if (!type.isSignlessInteger(1)) return nullptr; - return b.create(loc, type, boolValue); + return BoolConstantOp::create(b, loc, type, boolValue); } // Materialize integer attributes as `index`. @@ -46,7 +46,7 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value, return nullptr; assert(indexValue.getValue().getBitWidth() == IndexType::kInternalStorageBitWidth); - return b.create(loc, indexValue); + return ConstantOp::create(b, loc, indexValue); } return nullptr; @@ -715,11 +715,11 @@ LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) { index::CmpOp newCmp; if (rhsIsZero) - newCmp = rewriter.create(op.getLoc(), op.getPred(), - subOp.getLhs(), subOp.getRhs()); + newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(), + subOp.getLhs(), subOp.getRhs()); else - newCmp = rewriter.create(op.getLoc(), op.getPred(), - subOp.getRhs(), subOp.getLhs()); + newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(), + subOp.getRhs(), subOp.getLhs()); rewriter.replaceOp(op, newCmp); return success(); } diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp index ff6af63eee531..364e4d385fd62 100644 --- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp @@ -135,8 +135,9 @@ struct GlobalStoreOpInterface auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); auto loc = globalStoreOp.getLoc(); - auto targetMemref = rewriter.create( - loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference()); + auto targetMemref = memref::GetGlobalOp::create( + rewriter, loc, memrefType, + globalStoreOp.getGlobalAttr().getLeafReference()); auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options, state); diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp index 7940ff60a48e7..f52c3f99189d2 100644 --- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp +++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp @@ -60,8 +60,8 @@ struct FoldRank final : public mlir::OpRewritePattern { if (!isa(dltiAttr.value())) return op->emitError() << "Expected an integer attribute for MPI:comm_world_rank"; - Value res = b.create( - op.getLoc(), cast(dltiAttr.value()).getInt()); + Value res = arith::ConstantIndexOp::create( + b, op.getLoc(), cast(dltiAttr.value()).getInt()); if (Value retVal = op.getRetval()) b.replaceOp(op, {retVal, res}); else diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp index 26441a9d78658..a21631cbf8510 100644 --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -746,7 +746,7 @@ Operation *math::MathDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto poison = dyn_cast(value)) - return builder.create(loc, type, poison); + return ub::PoisonOp::create(builder, loc, type, poison); return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp index 13e2a4b5541b2..31785eb20a642 100644 --- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp +++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp @@ -65,7 +65,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op, // Maybe broadcasts scalar value into vector type compatible with `op`. auto bcast = [&](Value value) -> Value { if (auto vec = dyn_cast(op.getType())) - return rewriter.create(op.getLoc(), vec, value); + return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value); return value; }; @@ -84,15 +84,16 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op, // Replace `pow(x, 3.0)` with `x * x * x`. if (isExponentValue(3.0)) { Value square = - rewriter.create(op.getLoc(), ValueRange({x, x})); + arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x})); rewriter.replaceOpWithNewOp(op, ValueRange({x, square})); return success(); } // Replace `pow(x, -1.0)` with `1.0 / x`. if (isExponentValue(-1.0)) { - Value one = rewriter.create( - loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0)); + Value one = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0)); rewriter.replaceOpWithNewOp(op, ValueRange({bcast(one), x})); return success(); } @@ -111,8 +112,8 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op, // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`. if (isExponentValue(0.75)) { - Value powHalf = rewriter.create(op.getLoc(), x); - Value powQuarter = rewriter.create(op.getLoc(), powHalf); + Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x); + Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf); rewriter.replaceOpWithNewOp(op, ValueRange{powHalf, powQuarter}); return success(); @@ -168,18 +169,18 @@ PowIStrengthReduction::matchAndRewrite( // Maybe broadcasts scalar value into vector type compatible with `op`. auto bcast = [&loc, &op, &rewriter](Value value) -> Value { if (auto vec = dyn_cast(op.getType())) - return rewriter.create(loc, vec, value); + return vector::BroadcastOp::create(rewriter, loc, vec, value); return value; }; Value one; Type opType = getElementTypeOrSelf(op.getType()); if constexpr (std::is_same_v) - one = rewriter.create( - loc, rewriter.getFloatAttr(opType, 1.0)); + one = arith::ConstantOp::create(rewriter, loc, + rewriter.getFloatAttr(opType, 1.0)); else - one = rewriter.create( - loc, rewriter.getIntegerAttr(opType, 1)); + one = arith::ConstantOp::create(rewriter, loc, + rewriter.getIntegerAttr(opType, 1)); // Replace `[fi]powi(x, 0)` with `1`. if (exponentValue == 0) { @@ -208,12 +209,12 @@ PowIStrengthReduction::matchAndRewrite( // with: // (1 / x) * (1 / x) * (1 / x) * ... for (unsigned i = 1; i < exponentValue; ++i) - result = rewriter.create(loc, result, base); + result = MulOpTy::create(rewriter, loc, result, base); // Inverse the base for negative exponent, i.e. for // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`. if (exponentIsNegative) - result = rewriter.create(loc, bcast(one), result); + result = DivOpTy::create(rewriter, loc, bcast(one), result); rewriter.replaceOp(op, result); return success(); diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index bccd486def4bf..5edb6e28fb018 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -32,11 +32,11 @@ static Value createFloatConst(Location loc, Type type, APFloat value, APFloat::rmNearestTiesToEven, &losesInfo); auto attr = b.getFloatAttr(eltType, value); if (auto shapedTy = dyn_cast(type)) { - return b.create(loc, - DenseElementsAttr::get(shapedTy, attr)); + return arith::ConstantOp::create(b, loc, + DenseElementsAttr::get(shapedTy, attr)); } - return b.create(loc, attr); + return arith::ConstantOp::create(b, loc, attr); } static Value createFloatConst(Location loc, Type type, double value, @@ -49,11 +49,11 @@ static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b) { auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value); if (auto shapedTy = dyn_cast(type)) { - return b.create(loc, - DenseElementsAttr::get(shapedTy, attr)); + return arith::ConstantOp::create(b, loc, + DenseElementsAttr::get(shapedTy, attr)); } - return b.create(loc, attr); + return arith::ConstantOp::create(b, loc, attr); } static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) { @@ -61,11 +61,11 @@ static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) { Type i64Ty = b.getI64Type(); if (auto shapedTy = dyn_cast(opType)) i64Ty = shapedTy.clone(i64Ty); - Value fixedConvert = b.create(i64Ty, operand); - Value fpFixedConvert = b.create(opType, fixedConvert); + Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand); + Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert); // The truncation does not preserve the sign when the truncated // value is -0. So here the sign is copied again. - return b.create(fpFixedConvert, operand); + return math::CopySignOp::create(b, fpFixedConvert, operand); } // sinhf(float x) -> (exp(x) - exp(-x)) / 2 @@ -74,12 +74,12 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) { Value operand = op.getOperand(); Type opType = operand.getType(); - Value exp = b.create(operand); - Value neg = b.create(operand); - Value nexp = b.create(neg); - Value sub = b.create(exp, nexp); + Value exp = math::ExpOp::create(b, operand); + Value neg = arith::NegFOp::create(b, operand); + Value nexp = math::ExpOp::create(b, neg); + Value sub = arith::SubFOp::create(b, exp, nexp); Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); - Value res = b.create(sub, half); + Value res = arith::MulFOp::create(b, sub, half); rewriter.replaceOp(op, res); return success(); } @@ -90,12 +90,12 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) { Value operand = op.getOperand(); Type opType = operand.getType(); - Value exp = b.create(operand); - Value neg = b.create(operand); - Value nexp = b.create(neg); - Value add = b.create(exp, nexp); + Value exp = math::ExpOp::create(b, operand); + Value neg = arith::NegFOp::create(b, operand); + Value nexp = math::ExpOp::create(b, neg); + Value add = arith::AddFOp::create(b, exp, nexp); Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); - Value res = b.create(add, half); + Value res = arith::MulFOp::create(b, add, half); rewriter.replaceOp(op, res); return success(); } @@ -116,23 +116,23 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) { Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter); // Compute sign(x) = cast(x < 0) * (-2) + 1 - Value isNegative = rewriter.create( - loc, arith::CmpFPredicate::OLT, op.getOperand(), zero); + Value isNegative = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero); Value isNegativeFloat = - rewriter.create(loc, floatType, isNegative); + arith::UIToFPOp::create(rewriter, loc, floatType, isNegative); Value isNegativeTimesNegTwo = - rewriter.create(loc, isNegativeFloat, negTwo); - Value sign = rewriter.create(loc, isNegativeTimesNegTwo, one); + arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo); + Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one); // Normalize input to positive value: y = sign(x) * x - Value positiveX = rewriter.create(loc, sign, op.getOperand()); + Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand()); // Decompose on normalized input - Value negDoubledX = rewriter.create(loc, negTwo, positiveX); - Value exp2x = rewriter.create(loc, negDoubledX); - Value dividend = rewriter.create(loc, one, exp2x); - Value divisor = rewriter.create(loc, one, exp2x); - Value positiveRes = rewriter.create(loc, dividend, divisor); + Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX); + Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX); + Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x); + Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x); + Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor); // Multiply result by sign(x) to retain signs from negative inputs rewriter.replaceOpWithNewOp(op, sign, positiveRes); @@ -145,9 +145,9 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operand = op.getOperand(); Type type = operand.getType(); - Value sin = b.create(type, operand); - Value cos = b.create(type, operand); - Value div = b.create(type, sin, cos); + Value sin = math::SinOp::create(b, type, operand); + Value cos = math::CosOp::create(b, type, operand); + Value div = arith::DivFOp::create(b, type, sin, cos); rewriter.replaceOp(op, div); return success(); } @@ -160,10 +160,10 @@ static LogicalResult convertAsinhOp(math::AsinhOp op, Type opType = operand.getType(); Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); - Value fma = b.create(operand, operand, one); - Value sqrt = b.create(fma); - Value add = b.create(operand, sqrt); - Value res = b.create(add); + Value fma = math::FmaOp::create(b, operand, operand, one); + Value sqrt = math::SqrtOp::create(b, fma); + Value add = arith::AddFOp::create(b, operand, sqrt); + Value res = math::LogOp::create(b, add); rewriter.replaceOp(op, res); return success(); } @@ -176,10 +176,10 @@ static LogicalResult convertAcoshOp(math::AcoshOp op, Type opType = operand.getType(); Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter); - Value fma = b.create(operand, operand, negOne); - Value sqrt = b.create(fma); - Value add = b.create(operand, sqrt); - Value res = b.create(add); + Value fma = math::FmaOp::create(b, operand, operand, negOne); + Value sqrt = math::SqrtOp::create(b, fma); + Value add = arith::AddFOp::create(b, operand, sqrt); + Value res = math::LogOp::create(b, add); rewriter.replaceOp(op, res); return success(); } @@ -192,13 +192,13 @@ static LogicalResult convertAtanhOp(math::AtanhOp op, Type opType = operand.getType(); Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); - Value add = b.create(operand, one); - Value neg = b.create(operand); - Value sub = b.create(neg, one); - Value div = b.create(add, sub); - Value log = b.create(div); + Value add = arith::AddFOp::create(b, operand, one); + Value neg = arith::NegFOp::create(b, operand); + Value sub = arith::AddFOp::create(b, neg, one); + Value div = arith::DivFOp::create(b, add, sub); + Value log = math::LogOp::create(b, div); Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); - Value res = b.create(log, half); + Value res = arith::MulFOp::create(b, log, half); rewriter.replaceOp(op, res); return success(); } @@ -209,8 +209,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) { Value operandB = op.getOperand(1); Value operandC = op.getOperand(2); Type type = op.getType(); - Value mult = b.create(type, operandA, operandB); - Value add = b.create(type, mult, operandC); + Value mult = arith::MulFOp::create(b, type, operandA, operandB); + Value add = arith::AddFOp::create(b, type, mult, operandC); rewriter.replaceOp(op, add); return success(); } @@ -235,11 +235,12 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); - Value gtCheck = b.create(arith::CmpFPredicate::OGT, operand, - fpFixedConvert); - Value incrValue = b.create(op->getLoc(), gtCheck, one, zero); + Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand, + fpFixedConvert); + Value incrValue = + arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero); - Value ret = b.create(opType, fpFixedConvert, incrValue); + Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue); rewriter.replaceOp(op, ret); return success(); } @@ -257,9 +258,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op, auto convertFPowItoPowf = [&]() -> LogicalResult { Value castPowerToFp = - rewriter.create(op.getLoc(), baseType, power); - Value res = rewriter.create(op.getLoc(), baseType, base, - castPowerToFp); + arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power); + Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base, + castPowerToFp); rewriter.replaceOp(op, res); return success(); }; @@ -280,9 +281,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op, while (absPower > 0) { if (absPower & 1) - res = b.create(baseType, base, res); + res = arith::MulFOp::create(b, baseType, base, res); absPower >>= 1; - base = b.create(baseType, base, base); + base = arith::MulFOp::create(b, baseType, base, base); } // Make sure not to introduce UB in case of negative power. @@ -302,14 +303,14 @@ static LogicalResult convertFPowIOp(math::FPowIOp op, createFloatConst(op->getLoc(), baseType, APFloat::getInf(sem, /*Negative=*/true), rewriter); Value zeroEqCheck = - b.create(arith::CmpFPredicate::OEQ, res, zero); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero); Value negZeroEqCheck = - b.create(arith::CmpFPredicate::OEQ, res, negZero); - res = b.create(baseType, one, res); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero); + res = arith::DivFOp::create(b, baseType, one, res); res = - b.create(op->getLoc(), zeroEqCheck, posInfinity, res); - res = b.create(op->getLoc(), negZeroEqCheck, negInfinity, - res); + arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res); + res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity, + res); } rewriter.replaceOp(op, res); @@ -330,7 +331,7 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { cast(getElementTypeOrSelf(typeB)).getFloatSemantics(); APFloat valueB(sem); auto mulf = [&](Value x, Value y) -> Value { - return b.create(x, y); + return arith::MulFOp::create(b, x, y); }; if (matchPattern(operandB, m_ConstantFloat(&valueB))) { if (valueB.isZero()) { @@ -347,19 +348,19 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { if (valueB.isExactlyValue(-1.0)) { // a^(-1) -> 1 / a Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter); - Value div = b.create(one, operandA); + Value div = arith::DivFOp::create(b, one, operandA); rewriter.replaceOp(op, div); return success(); } if (valueB.isExactlyValue(0.5)) { // a^(1/2) -> sqrt(a) - Value sqrt = b.create(operandA); + Value sqrt = math::SqrtOp::create(b, operandA); rewriter.replaceOp(op, sqrt); return success(); } if (valueB.isExactlyValue(-0.5)) { // a^(-1/2) -> 1 / sqrt(a) - Value rsqrt = b.create(operandA); + Value rsqrt = math::RsqrtOp::create(b, operandA); rewriter.replaceOp(op, rsqrt); return success(); } @@ -372,7 +373,7 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { // a^(-2) -> 1 / (a * a) Value one = createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter); - Value div = b.create(one, mulf(operandA, operandA)); + Value div = arith::DivFOp::create(b, one, mulf(operandA, operandA)); rewriter.replaceOp(op, div); return success(); } @@ -382,9 +383,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { } } - Value logA = b.create(operandA); - Value mult = b.create(operandB, logA); - Value expResult = b.create(mult); + Value logA = math::LogOp::create(b, operandA); + Value mult = arith::MulFOp::create(b, operandB, logA); + Value expResult = math::ExpOp::create(b, mult); rewriter.replaceOp(op, expResult); return success(); } @@ -399,8 +400,8 @@ static LogicalResult convertExp2fOp(math::Exp2Op op, Value operand = op.getOperand(); Type opType = operand.getType(); Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b); - Value mult = b.create(opType, operand, ln2); - Value exp = b.create(op->getLoc(), mult); + Value mult = arith::MulFOp::create(b, opType, operand, ln2); + Value exp = math::ExpOp::create(b, op->getLoc(), mult); rewriter.replaceOp(op, exp); return success(); } @@ -426,8 +427,8 @@ static LogicalResult convertRoundOp(math::RoundOp op, Value c127 = createIntConst(loc, i32Ty, 127, b); Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b); - Value incrValue = b.create(half, operand); - Value add = b.create(opType, operand, incrValue); + Value incrValue = math::CopySignOp::create(b, half, operand); + Value add = arith::AddFOp::create(b, opType, operand, incrValue); Value fpFixedConvert = createTruncatedFPValue(add, b); // There are three cases where adding 0.5 to the value and truncating by @@ -450,15 +451,15 @@ static LogicalResult convertRoundOp(math::RoundOp op, // i64 leading to wrong outputs. // // All three cases satisfy the property `biasedExp >= 23`. - Value operandBitcast = b.create(i32Ty, operand); - Value operandExp = b.create( - b.create(operandBitcast, c23), expMask); - Value operandBiasedExp = b.create(operandExp, c127); - Value isSpecialValOrLargeVal = - b.create(arith::CmpIPredicate::sge, operandBiasedExp, c23); - - Value result = b.create(isSpecialValOrLargeVal, operand, - fpFixedConvert); + Value operandBitcast = arith::BitcastOp::create(b, i32Ty, operand); + Value operandExp = arith::AndIOp::create( + b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask); + Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127); + Value isSpecialValOrLargeVal = arith::CmpIOp::create( + b, arith::CmpIPredicate::sge, operandBiasedExp, c23); + + Value result = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, + fpFixedConvert); rewriter.replaceOp(op, result); return success(); } @@ -488,21 +489,21 @@ static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, auto bits = createIntConst(loc, operandTy, half, rewriter); auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter); - Value pred = - rewriter.create(loc, arith::CmpIPredicate::ule, x, mask); - Value add = rewriter.create(loc, count, bits); - Value shift = rewriter.create(loc, x, bits); + Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ule, + x, mask); + Value add = arith::AddIOp::create(rewriter, loc, count, bits); + Value shift = arith::ShLIOp::create(rewriter, loc, x, bits); - x = rewriter.create(loc, pred, shift, x); - count = rewriter.create(loc, pred, add, count); + x = arith::SelectOp::create(rewriter, loc, pred, shift, x); + count = arith::SelectOp::create(rewriter, loc, pred, add, count); } Value zero = createIntConst(loc, operandTy, 0, rewriter); - Value pred = rewriter.create(loc, arith::CmpIPredicate::eq, - operand, zero); + Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, + operand, zero); Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter); - Value sel = rewriter.create(loc, pred, bwval, count); + Value sel = arith::SelectOp::create(rewriter, loc, pred, bwval, count); rewriter.replaceOp(op, sel); return success(); } @@ -549,29 +550,29 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b); Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b); - Value operandBitcast = b.create(iTy, operand); - Value round = b.create(operand); - Value roundBitcast = b.create(iTy, round); + Value operandBitcast = arith::BitcastOp::create(b, iTy, operand); + Value round = math::RoundOp::create(b, operand); + Value roundBitcast = arith::BitcastOp::create(b, iTy, round); // Get biased exponents for operand and round(operand) - Value operandExp = b.create( - b.create(operandBitcast, c23), expMask); - Value operandBiasedExp = b.create(operandExp, c127); - Value roundExp = b.create( - b.create(roundBitcast, c23), expMask); - Value roundBiasedExp = b.create(roundExp, c127); + Value operandExp = arith::AndIOp::create( + b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask); + Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127); + Value roundExp = arith::AndIOp::create( + b, arith::ShRUIOp::create(b, roundBitcast, c23), expMask); + Value roundBiasedExp = arith::SubIOp::create(b, roundExp, c127); auto safeShiftRight = [&](Value x, Value shift) -> Value { // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior - Value clampedShift = b.create(shift, c0); - clampedShift = b.create(clampedShift, c31); - return b.create(x, clampedShift); + Value clampedShift = arith::MaxSIOp::create(b, shift, c0); + clampedShift = arith::MinSIOp::create(b, clampedShift, c31); + return arith::ShRUIOp::create(b, x, clampedShift); }; auto maskMantissa = [&](Value mantissa, Value mantissaMaskRightShift) -> Value { Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift); - return b.create(mantissa, shiftedMantissaMask); + return arith::AndIOp::create(b, mantissa, shiftedMantissaMask); }; // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring @@ -589,13 +590,13 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, // `biasedExp > 23`, so they get treated as large numbers with no room for // decimals, which are always even. Value roundBiasedExpEq0 = - b.create(arith::CmpIPredicate::eq, roundBiasedExp, c0); - Value roundBiasedExpMinus1 = b.create(roundBiasedExp, c1); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, roundBiasedExp, c0); + Value roundBiasedExpMinus1 = arith::SubIOp::create(b, roundBiasedExp, c1); Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1); - Value roundIsNotEvenOrSpecialVal = b.create( - arith::CmpIPredicate::ne, roundMaskedMantissa, c0); + Value roundIsNotEvenOrSpecialVal = arith::CmpIOp::create( + b, arith::CmpIPredicate::ne, roundMaskedMantissa, c0); roundIsNotEvenOrSpecialVal = - b.create(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0); + arith::OrIOp::create(b, roundIsNotEvenOrSpecialVal, roundBiasedExpEq0); // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive // integers if the bit at index `biasedExp` starting from the left in the @@ -604,37 +605,37 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`, // so these are handled separately. In particular, if `biasedExp == -1`, the // value is halfway if the entire mantissa is zero. - Value operandBiasedExpEqNeg1 = b.create( - arith::CmpIPredicate::eq, operandBiasedExp, cNeg1); - Value expectedOperandMaskedMantissa = b.create( - operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp)); + Value operandBiasedExpEqNeg1 = arith::CmpIOp::create( + b, arith::CmpIPredicate::eq, operandBiasedExp, cNeg1); + Value expectedOperandMaskedMantissa = arith::SelectOp::create( + b, operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp)); Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp); Value operandIsHalfway = - b.create(arith::CmpIPredicate::eq, operandMaskedMantissa, - expectedOperandMaskedMantissa); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, operandMaskedMantissa, + expectedOperandMaskedMantissa); // Ensure `biasedExp` is in the valid range for half values. - Value operandBiasedExpGeNeg1 = b.create( - arith::CmpIPredicate::sge, operandBiasedExp, cNeg1); - Value operandBiasedExpLt23 = - b.create(arith::CmpIPredicate::slt, operandBiasedExp, c23); + Value operandBiasedExpGeNeg1 = arith::CmpIOp::create( + b, arith::CmpIPredicate::sge, operandBiasedExp, cNeg1); + Value operandBiasedExpLt23 = arith::CmpIOp::create( + b, arith::CmpIPredicate::slt, operandBiasedExp, c23); operandIsHalfway = - b.create(operandIsHalfway, operandBiasedExpLt23); + arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpLt23); operandIsHalfway = - b.create(operandIsHalfway, operandBiasedExpGeNeg1); + arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpGeNeg1); // Adjust rounded operand with `round(operand) - sign(operand)` to correct the // case where `round` rounded in the opposite direction of `roundeven`. - Value sign = b.create(c1Float, operand); - Value roundShifted = b.create(round, sign); + Value sign = math::CopySignOp::create(b, c1Float, operand); + Value roundShifted = arith::SubFOp::create(b, round, sign); // If the rounded value is even or a special value, we default to the behavior // of `math.round`. Value needsShift = - b.create(roundIsNotEvenOrSpecialVal, operandIsHalfway); - Value result = b.create(needsShift, roundShifted, round); + arith::AndIOp::create(b, roundIsNotEvenOrSpecialVal, operandIsHalfway); + Value result = arith::SelectOp::create(b, needsShift, roundShifted, round); // The `x - sign` adjustment does not preserve the sign when we are adjusting // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is // rounded to -0.0. - result = b.create(result, operand); + result = math::CopySignOp::create(b, result, operand); rewriter.replaceOp(op, result); return success(); } @@ -656,7 +657,7 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op, Location loc = op->getLoc(); auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter); - auto sqrtOp = rewriter.create(loc, operand); + auto sqrtOp = math::SqrtOp::create(rewriter, loc, operand); rewriter.replaceOpWithNewOp(op, constOneFloat, sqrtOp); return success(); } diff --git a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp index a570ed5118ef0..9d6ad613fc945 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp @@ -73,7 +73,7 @@ void mlir::math::populateExtendToSupportedTypesTypeConverter( }); typeConverter.addTargetMaterialization( [](OpBuilder &b, Type target, ValueRange input, Location loc) { - auto extFOp = b.create(loc, target, input); + auto extFOp = arith::ExtFOp::create(b, loc, target, input); extFOp.setFastmath(arith::FastMathFlags::contract); return extFOp; }); @@ -104,7 +104,7 @@ LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite( for (auto [result, newType, origType] : llvm::zip_equal( results, (*legalized)->getResultTypes(), op->getResultTypes())) { if (newType != origType) { - auto truncFOp = rewriter.create(loc, origType, result); + auto truncFOp = arith::TruncFOp::create(rewriter, loc, origType, result); truncFOp.setFastmath(arith::FastMathFlags::contract); result = truncFOp.getResult(); } diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index dd2dfe372b683..76720cfd4a98c 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -72,7 +72,7 @@ static Value broadcast(ImplicitLocOpBuilder &builder, Value value, std::optional shape) { assert(!isa(value.getType()) && "must be scalar value"); auto type = broadcast(value.getType(), shape); - return shape ? builder.create(type, value) : value; + return shape ? BroadcastOp::create(builder, type, value) : value; } //----------------------------------------------------------------------------// @@ -130,7 +130,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, auto eltType = cast(operand.getType()).getElementType(); auto expandedType = VectorType::get(expandedShape, eltType); expandedOperands[i] = - builder.create(expandedType, operand); + vector::ShapeCastOp::create(builder, expandedType, operand); } } @@ -148,7 +148,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, SmallVector extracted(expandedOperands.size()); for (const auto &tuple : llvm::enumerate(expandedOperands)) extracted[tuple.index()] = - builder.create(tuple.value(), offsets); + vector::ExtractOp::create(builder, tuple.value(), offsets); results[i] = compute(extracted); } @@ -156,16 +156,16 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, // Stitch results together into one large vector. Type resultEltType = cast(results[0].getType()).getElementType(); Type resultExpandedType = VectorType::get(expandedShape, resultEltType); - Value result = builder.create( - resultExpandedType, builder.getZeroAttr(resultExpandedType)); + Value result = arith::ConstantOp::create( + builder, resultExpandedType, builder.getZeroAttr(resultExpandedType)); for (int64_t i = 0; i < maxIndex; ++i) - result = builder.create(results[i], result, - delinearize(i, strides)); + result = vector::InsertOp::create(builder, results[i], result, + delinearize(i, strides)); // Reshape back to the original vector shape. - return builder.create( - VectorType::get(inputShape, resultEltType), result); + return vector::ShapeCastOp::create( + builder, VectorType::get(inputShape, resultEltType), result); } //----------------------------------------------------------------------------// @@ -173,28 +173,28 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, //----------------------------------------------------------------------------// static Value boolCst(ImplicitLocOpBuilder &builder, bool value) { - return builder.create(builder.getBoolAttr(value)); + return arith::ConstantOp::create(builder, builder.getBoolAttr(value)); } static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType) { assert((elementType.isF16() || elementType.isF32()) && "x must be f16 or f32 type."); - return builder.create( - builder.getFloatAttr(elementType, value)); + return arith::ConstantOp::create(builder, + builder.getFloatAttr(elementType, value)); } static Value f32Cst(ImplicitLocOpBuilder &builder, double value) { - return builder.create(builder.getF32FloatAttr(value)); + return arith::ConstantOp::create(builder, builder.getF32FloatAttr(value)); } static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { - return builder.create(builder.getI32IntegerAttr(value)); + return arith::ConstantOp::create(builder, builder.getI32IntegerAttr(value)); } static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { Value i32Value = i32Cst(builder, static_cast(bits)); - return builder.create(builder.getF32Type(), i32Value); + return arith::BitcastOp::create(builder, builder.getF32Type(), i32Value); } //----------------------------------------------------------------------------// @@ -203,15 +203,17 @@ static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { // Return the minimum of the two values or NaN if value is NaN static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) { - return builder.create( - builder.create(arith::CmpFPredicate::ULT, value, bound), + return arith::SelectOp::create( + builder, + arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT, value, bound), value, bound); } // Return the maximum of the two values or NaN if value is NaN static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) { - return builder.create( - builder.create(arith::CmpFPredicate::UGT, value, bound), + return arith::SelectOp::create( + builder, + arith::CmpFOp::create(builder, arith::CmpFPredicate::UGT, value, bound), value, bound); } @@ -241,24 +243,24 @@ static std::pair frexp(ImplicitLocOpBuilder &builder, Value arg, Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u); // Bitcast to i32 for bitwise operations. - Value i32Half = builder.create(i32, cstHalf); - Value i32InvMantMask = builder.create(i32, cstInvMantMask); - Value i32Arg = builder.create(i32Vec, arg); + Value i32Half = arith::BitcastOp::create(builder, i32, cstHalf); + Value i32InvMantMask = arith::BitcastOp::create(builder, i32, cstInvMantMask); + Value i32Arg = arith::BitcastOp::create(builder, i32Vec, arg); // Compute normalized fraction. - Value tmp0 = builder.create(i32Arg, bcast(i32InvMantMask)); - Value tmp1 = builder.create(tmp0, bcast(i32Half)); - Value normalizedFraction = builder.create(f32Vec, tmp1); + Value tmp0 = arith::AndIOp::create(builder, i32Arg, bcast(i32InvMantMask)); + Value tmp1 = arith::OrIOp::create(builder, tmp0, bcast(i32Half)); + Value normalizedFraction = arith::BitcastOp::create(builder, f32Vec, tmp1); // Compute exponent. - Value arg0 = isPositive ? arg : builder.create(arg); - Value biasedExponentBits = builder.create( - builder.create(i32Vec, arg0), + Value arg0 = isPositive ? arg : math::AbsFOp::create(builder, arg); + Value biasedExponentBits = arith::ShRUIOp::create( + builder, arith::BitcastOp::create(builder, i32Vec, arg0), bcast(i32Cst(builder, 23))); Value biasedExponent = - builder.create(f32Vec, biasedExponentBits); + arith::SIToFPOp::create(builder, f32Vec, biasedExponentBits); Value exponent = - builder.create(biasedExponent, bcast(cst126f)); + arith::SubFOp::create(builder, biasedExponent, bcast(cst126f)); return {normalizedFraction, exponent}; } @@ -278,10 +280,10 @@ static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { // Set the exponent bias to zero. auto bias = bcast(i32Cst(builder, 127)); - Value biasedArg = builder.create(arg, bias); + Value biasedArg = arith::AddIOp::create(builder, arg, bias); Value exp2ValueInt = - builder.create(biasedArg, exponetBitLocation); - Value exp2ValueF32 = builder.create(f32Vec, exp2ValueInt); + arith::ShLIOp::create(builder, biasedArg, exponetBitLocation); + Value exp2ValueF32 = arith::BitcastOp::create(builder, f32Vec, exp2ValueInt); return exp2ValueF32; } @@ -300,10 +302,10 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, if (coeffs.size() == 1) return coeffs[0]; - Value res = builder.create(x, coeffs[coeffs.size() - 1], - coeffs[coeffs.size() - 2]); + Value res = math::FmaOp::create(builder, x, coeffs[coeffs.size() - 1], + coeffs[coeffs.size() - 2]); for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) { - res = builder.create(x, res, coeffs[i]); + res = math::FmaOp::create(builder, x, res, coeffs[i]); } return res; } @@ -343,9 +345,9 @@ LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) { Location loc = op->getLoc(); SmallVector operands; for (auto operand : op->getOperands()) - operands.push_back(rewriter.create(loc, newType, operand)); + operands.push_back(arith::ExtFOp::create(rewriter, loc, newType, operand)); auto result = - rewriter.create(loc, TypeRange{newType}, operands, op->getAttrs()); + T::create(rewriter, loc, TypeRange{newType}, operands, op->getAttrs()); rewriter.replaceOpWithNewOp(op, origType, result); return success(); } @@ -393,18 +395,18 @@ AtanApproximation::matchAndRewrite(math::AtanOp op, std::optional shape = vectorShape(op.getOperand()); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); - Value abs = builder.create(operand); + Value abs = math::AbsFOp::create(builder, operand); auto one = broadcast(builder, f32Cst(builder, 1.0), shape); // When 0.66 < x <= 2.41 we do (x-1) / (x+1): auto twoThirds = broadcast(builder, f32Cst(builder, 0.66), shape); Value cmp2 = - builder.create(arith::CmpFPredicate::OGT, abs, twoThirds); - Value addone = builder.create(abs, one); - Value subone = builder.create(abs, one); - Value xnum = builder.create(cmp2, subone, abs); - Value xden = builder.create(cmp2, addone, one); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, abs, twoThirds); + Value addone = arith::AddFOp::create(builder, abs, one); + Value subone = arith::SubFOp::create(builder, abs, one); + Value xnum = arith::SelectOp::create(builder, cmp2, subone, abs); + Value xden = arith::SelectOp::create(builder, cmp2, addone, one); auto bcast = [&](Value value) -> Value { return broadcast(builder, value, shape); @@ -413,12 +415,12 @@ AtanApproximation::matchAndRewrite(math::AtanOp op, // Break into the <= 0.66 or > 2.41 we do x or 1/x: auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880)); Value cmp1 = - builder.create(arith::CmpFPredicate::OGT, abs, tan3pio8); - xnum = builder.create(cmp1, one, xnum); - xden = builder.create(cmp1, abs, xden); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, abs, tan3pio8); + xnum = arith::SelectOp::create(builder, cmp1, one, xnum); + xden = arith::SelectOp::create(builder, cmp1, abs, xden); - Value x = builder.create(xnum, xden); - Value xx = builder.create(x, x); + Value x = arith::DivFOp::create(builder, xnum, xden); + Value xx = arith::MulFOp::create(builder, x, x); // Perform the Taylor series approximation for atan over the range // [0.0, 0.66]. @@ -435,31 +437,31 @@ AtanApproximation::matchAndRewrite(math::AtanOp op, // Apply the polynomial approximation for the numerator: Value n = p0; - n = builder.create(xx, n, p1); - n = builder.create(xx, n, p2); - n = builder.create(xx, n, p3); - n = builder.create(xx, n, p4); - n = builder.create(n, xx); + n = math::FmaOp::create(builder, xx, n, p1); + n = math::FmaOp::create(builder, xx, n, p2); + n = math::FmaOp::create(builder, xx, n, p3); + n = math::FmaOp::create(builder, xx, n, p4); + n = arith::MulFOp::create(builder, n, xx); // Apply the polynomial approximation for the denominator: Value d = q0; - d = builder.create(xx, d, q1); - d = builder.create(xx, d, q2); - d = builder.create(xx, d, q3); - d = builder.create(xx, d, q4); + d = math::FmaOp::create(builder, xx, d, q1); + d = math::FmaOp::create(builder, xx, d, q2); + d = math::FmaOp::create(builder, xx, d, q3); + d = math::FmaOp::create(builder, xx, d, q4); // Compute approximation of theta: - Value ans0 = builder.create(n, d); - ans0 = builder.create(ans0, x, x); + Value ans0 = arith::DivFOp::create(builder, n, d); + ans0 = math::FmaOp::create(builder, ans0, x, x); // Correct for the input mapping's angles: Value mpi4 = bcast(f32Cst(builder, llvm::numbers::pi / 4)); - Value ans2 = builder.create(mpi4, ans0); - Value ans = builder.create(cmp2, ans2, ans0); + Value ans2 = arith::AddFOp::create(builder, mpi4, ans0); + Value ans = arith::SelectOp::create(builder, cmp2, ans2, ans0); Value mpi2 = bcast(f32Cst(builder, llvm::numbers::pi / 2)); - Value ans1 = builder.create(mpi2, ans0); - ans = builder.create(cmp1, ans1, ans); + Value ans1 = arith::SubFOp::create(builder, mpi2, ans0); + ans = arith::SelectOp::create(builder, cmp1, ans1, ans); // Correct for signing of the input. rewriter.replaceOpWithNewOp(op, ans, operand); @@ -492,44 +494,46 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op, std::optional shape = vectorShape(op.getResult()); // Compute atan in the valid range. - auto div = builder.create(y, x); - auto atan = builder.create(div); + auto div = arith::DivFOp::create(builder, y, x); + auto atan = math::AtanOp::create(builder, div); // Determine what the atan would be for a 180 degree rotation. auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape); auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape); - auto addPi = builder.create(atan, pi); - auto subPi = builder.create(atan, pi); + auto addPi = arith::AddFOp::create(builder, atan, pi); + auto subPi = arith::SubFOp::create(builder, atan, pi); auto atanGt = - builder.create(arith::CmpFPredicate::OGT, atan, zero); - auto flippedAtan = builder.create(atanGt, subPi, addPi); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, atan, zero); + auto flippedAtan = arith::SelectOp::create(builder, atanGt, subPi, addPi); // Determine whether to directly use atan or use the 180 degree flip - auto xGt = builder.create(arith::CmpFPredicate::OGT, x, zero); - Value result = builder.create(xGt, atan, flippedAtan); + auto xGt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, x, zero); + Value result = arith::SelectOp::create(builder, xGt, atan, flippedAtan); // Handle x = 0, y > 0 Value xZero = - builder.create(arith::CmpFPredicate::OEQ, x, zero); - Value yGt = builder.create(arith::CmpFPredicate::OGT, y, zero); - Value isHalfPi = builder.create(xZero, yGt); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, x, zero); + Value yGt = + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, y, zero); + Value isHalfPi = arith::AndIOp::create(builder, xZero, yGt); auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape); - result = builder.create(isHalfPi, halfPi, result); + result = arith::SelectOp::create(builder, isHalfPi, halfPi, result); // Handle x = 0, y < 0 - Value yLt = builder.create(arith::CmpFPredicate::OLT, y, zero); - Value isNegativeHalfPiPi = builder.create(xZero, yLt); + Value yLt = + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, y, zero); + Value isNegativeHalfPiPi = arith::AndIOp::create(builder, xZero, yLt); auto negativeHalfPiPi = broadcast(builder, f32Cst(builder, -1.57079632679f), shape); - result = builder.create(isNegativeHalfPiPi, negativeHalfPiPi, - result); + result = arith::SelectOp::create(builder, isNegativeHalfPiPi, + negativeHalfPiPi, result); // Handle x = 0, y = 0; Value yZero = - builder.create(arith::CmpFPredicate::OEQ, y, zero); - Value isNan = builder.create(xZero, yZero); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, y, zero); + Value isNan = arith::AndIOp::create(builder, xZero, yZero); Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape); - result = builder.create(isNan, cstNan, result); + result = arith::SelectOp::create(builder, isNan, cstNan, result); rewriter.replaceOp(op, result); return success(); @@ -569,9 +573,9 @@ TanhApproximation::matchAndRewrite(math::TanhOp op, // Mask for tiny values that are approximated with `operand`. Value tiny = bcast(f32Cst(builder, 0.0004f)); - Value tinyMask = builder.create( - arith::CmpFPredicate::OLT, builder.create(op.getOperand()), - tiny); + Value tinyMask = arith::CmpFOp::create( + builder, arith::CmpFPredicate::OLT, + math::AbsFOp::create(builder, op.getOperand()), tiny); // The monomial coefficients of the numerator polynomial (odd). Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f)); @@ -589,25 +593,25 @@ TanhApproximation::matchAndRewrite(math::TanhOp op, Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f)); // Since the polynomials are odd/even, we need x^2. - Value x2 = builder.create(x, x); + Value x2 = arith::MulFOp::create(builder, x, x); // Evaluate the numerator polynomial p. - Value p = builder.create(x2, alpha13, alpha11); - p = builder.create(x2, p, alpha9); - p = builder.create(x2, p, alpha7); - p = builder.create(x2, p, alpha5); - p = builder.create(x2, p, alpha3); - p = builder.create(x2, p, alpha1); - p = builder.create(x, p); + Value p = math::FmaOp::create(builder, x2, alpha13, alpha11); + p = math::FmaOp::create(builder, x2, p, alpha9); + p = math::FmaOp::create(builder, x2, p, alpha7); + p = math::FmaOp::create(builder, x2, p, alpha5); + p = math::FmaOp::create(builder, x2, p, alpha3); + p = math::FmaOp::create(builder, x2, p, alpha1); + p = arith::MulFOp::create(builder, x, p); // Evaluate the denominator polynomial q. - Value q = builder.create(x2, beta6, beta4); - q = builder.create(x2, q, beta2); - q = builder.create(x2, q, beta0); + Value q = math::FmaOp::create(builder, x2, beta6, beta4); + q = math::FmaOp::create(builder, x2, q, beta2); + q = math::FmaOp::create(builder, x2, q, beta0); // Divide the numerator by the denominator. - Value res = builder.create( - tinyMask, x, builder.create(p, q)); + Value res = arith::SelectOp::create(builder, tinyMask, x, + arith::DivFOp::create(builder, p, q)); rewriter.replaceOp(op, res); @@ -690,57 +694,57 @@ LogApproximationBase::logMatchAndRewrite(Op op, PatternRewriter &rewriter, // e -= 1; // x = x + x - 1.0; // } else { x = x - 1.0; } - Value mask = builder.create(arith::CmpFPredicate::OLT, x, - cstCephesSQRTHF); - Value tmp = builder.create(mask, x, cstZero); + Value mask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, + cstCephesSQRTHF); + Value tmp = arith::SelectOp::create(builder, mask, x, cstZero); - x = builder.create(x, cstOne); - e = builder.create( - e, builder.create(mask, cstOne, cstZero)); - x = builder.create(x, tmp); + x = arith::SubFOp::create(builder, x, cstOne); + e = arith::SubFOp::create( + builder, e, arith::SelectOp::create(builder, mask, cstOne, cstZero)); + x = arith::AddFOp::create(builder, x, tmp); - Value x2 = builder.create(x, x); - Value x3 = builder.create(x2, x); + Value x2 = arith::MulFOp::create(builder, x, x); + Value x3 = arith::MulFOp::create(builder, x2, x); // Evaluate the polynomial approximant of degree 8 in three parts. Value y0, y1, y2; - y0 = builder.create(cstCephesLogP0, x, cstCephesLogP1); - y1 = builder.create(cstCephesLogP3, x, cstCephesLogP4); - y2 = builder.create(cstCephesLogP6, x, cstCephesLogP7); - y0 = builder.create(y0, x, cstCephesLogP2); - y1 = builder.create(y1, x, cstCephesLogP5); - y2 = builder.create(y2, x, cstCephesLogP8); - y0 = builder.create(y0, x3, y1); - y0 = builder.create(y0, x3, y2); - y0 = builder.create(y0, x3); - - y0 = builder.create(cstNegHalf, x2, y0); - x = builder.create(x, y0); + y0 = math::FmaOp::create(builder, cstCephesLogP0, x, cstCephesLogP1); + y1 = math::FmaOp::create(builder, cstCephesLogP3, x, cstCephesLogP4); + y2 = math::FmaOp::create(builder, cstCephesLogP6, x, cstCephesLogP7); + y0 = math::FmaOp::create(builder, y0, x, cstCephesLogP2); + y1 = math::FmaOp::create(builder, y1, x, cstCephesLogP5); + y2 = math::FmaOp::create(builder, y2, x, cstCephesLogP8); + y0 = math::FmaOp::create(builder, y0, x3, y1); + y0 = math::FmaOp::create(builder, y0, x3, y2); + y0 = arith::MulFOp::create(builder, y0, x3); + + y0 = math::FmaOp::create(builder, cstNegHalf, x2, y0); + x = arith::AddFOp::create(builder, x, y0); if (base2) { Value cstLog2e = bcast(f32Cst(builder, static_cast(LOG2E_VALUE))); - x = builder.create(x, cstLog2e, e); + x = math::FmaOp::create(builder, x, cstLog2e, e); } else { Value cstLn2 = bcast(f32Cst(builder, static_cast(LN2_VALUE))); - x = builder.create(e, cstLn2, x); + x = math::FmaOp::create(builder, e, cstLn2, x); } - Value invalidMask = builder.create(arith::CmpFPredicate::ULT, - op.getOperand(), cstZero); - Value zeroMask = builder.create(arith::CmpFPredicate::OEQ, - op.getOperand(), cstZero); - Value posInfMask = builder.create(arith::CmpFPredicate::OEQ, - op.getOperand(), cstPosInf); + Value invalidMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT, + op.getOperand(), cstZero); + Value zeroMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, + op.getOperand(), cstZero); + Value posInfMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, + op.getOperand(), cstPosInf); // Filter out invalid values: // • x == 0 -> -INF // • x < 0 -> NAN // • x == +INF -> +INF - Value aproximation = builder.create( - zeroMask, cstMinusInf, - builder.create( - invalidMask, cstNan, - builder.create(posInfMask, cstPosInf, x))); + Value aproximation = arith::SelectOp::create( + builder, zeroMask, cstMinusInf, + arith::SelectOp::create( + builder, invalidMask, cstNan, + arith::SelectOp::create(builder, posInfMask, cstPosInf, x))); rewriter.replaceOp(op, aproximation); @@ -805,17 +809,18 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op, // "logLarge" below. Value cstOne = bcast(f32Cst(builder, 1.0f)); Value x = op.getOperand(); - Value u = builder.create(x, cstOne); + Value u = arith::AddFOp::create(builder, x, cstOne); Value uSmall = - builder.create(arith::CmpFPredicate::OEQ, u, cstOne); - Value logU = builder.create(u); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, u, cstOne); + Value logU = math::LogOp::create(builder, u); Value uInf = - builder.create(arith::CmpFPredicate::OEQ, u, logU); - Value logLarge = builder.create( - x, builder.create( - logU, builder.create(u, cstOne))); - Value approximation = builder.create( - builder.create(uSmall, uInf), x, logLarge); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, u, logU); + Value logLarge = arith::MulFOp::create( + builder, x, + arith::DivFOp::create(builder, logU, + arith::SubFOp::create(builder, u, cstOne))); + Value approximation = arith::SelectOp::create( + builder, arith::OrIOp::create(builder, uSmall, uInf), x, logLarge); rewriter.replaceOp(op, approximation); return success(); } @@ -853,36 +858,37 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op, }; auto fma = [&](Value a, Value b, Value c) -> Value { - return builder.create(a, b, c); + return math::FmaOp::create(builder, a, b, c); }; auto mul = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return arith::MulFOp::create(builder, a, b); }; auto sub = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return arith::SubFOp::create(builder, a, b); }; - auto abs = [&](Value a) -> Value { return builder.create(a); }; + auto abs = [&](Value a) -> Value { return math::AbsFOp::create(builder, a); }; - auto sqrt = [&](Value a) -> Value { return builder.create(a); }; + auto sqrt = [&](Value a) -> Value { + return math::SqrtOp::create(builder, a); + }; auto scopy = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return math::CopySignOp::create(builder, a, b); }; auto sel = [&](Value a, Value b, Value c) -> Value { - return builder.create(a, b, c); + return arith::SelectOp::create(builder, a, b, c); }; Value abso = abs(operand); Value aa = mul(operand, operand); Value opp = sqrt(sub(bcast(floatCst(builder, 1.0, elementType)), aa)); - Value gt = - builder.create(arith::CmpFPredicate::OGT, aa, - bcast(floatCst(builder, 0.5, elementType))); + Value gt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, aa, + bcast(floatCst(builder, 0.5, elementType))); Value x = sel(gt, opp, abso); @@ -948,51 +954,51 @@ AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op, }; auto fma = [&](Value a, Value b, Value c) -> Value { - return builder.create(a, b, c); + return math::FmaOp::create(builder, a, b, c); }; auto mul = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return arith::MulFOp::create(builder, a, b); }; - Value negOperand = builder.create(operand); + Value negOperand = arith::NegFOp::create(builder, operand); Value zero = bcast(floatCst(builder, 0.0, elementType)); Value half = bcast(floatCst(builder, 0.5, elementType)); Value negOne = bcast(floatCst(builder, -1.0, elementType)); Value selR = - builder.create(arith::CmpFPredicate::OGT, operand, zero); - Value r = builder.create(selR, negOperand, operand); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, operand, zero); + Value r = arith::SelectOp::create(builder, selR, negOperand, operand); Value chkConst = bcast(floatCst(builder, -0.5625, elementType)); Value firstPred = - builder.create(arith::CmpFPredicate::OGT, r, chkConst); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, r, chkConst); Value trueVal = fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)), bcast(floatCst(builder, 1.6839188885261840e+0, elementType)), - builder.create(r)); + math::AsinOp::create(builder, r)); - Value falseVal = builder.create(fma(half, r, half)); - falseVal = builder.create(falseVal); + Value falseVal = math::SqrtOp::create(builder, fma(half, r, half)); + falseVal = math::AsinOp::create(builder, falseVal); falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal); - r = builder.create(firstPred, trueVal, falseVal); + r = arith::SelectOp::create(builder, firstPred, trueVal, falseVal); // Check whether the operand lies in between [-1.0, 0.0). - Value greaterThanNegOne = - builder.create(arith::CmpFPredicate::OGE, operand, negOne); + Value greaterThanNegOne = arith::CmpFOp::create( + builder, arith::CmpFPredicate::OGE, operand, negOne); Value lessThanZero = - builder.create(arith::CmpFPredicate::OLT, operand, zero); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, operand, zero); Value betweenNegOneZero = - builder.create(greaterThanNegOne, lessThanZero); + arith::AndIOp::create(builder, greaterThanNegOne, lessThanZero); trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)), bcast(floatCst(builder, 1.6839188885261840e+0, elementType)), - builder.create(r)); + arith::NegFOp::create(builder, r)); Value finalVal = - builder.create(betweenNegOneZero, trueVal, r); + arith::SelectOp::create(builder, betweenNegOneZero, trueVal, r); rewriter.replaceOp(op, finalVal); return success(); @@ -1075,9 +1081,9 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, bounds[2] = bcast(floatCst(builder, 3.75f, elementType)); Value isNegativeArg = - builder.create(arith::CmpFPredicate::OLT, operand, zero); - Value negArg = builder.create(operand); - Value x = builder.create(isNegativeArg, negArg, operand); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, operand, zero); + Value negArg = arith::NegFOp::create(builder, operand); + Value x = arith::SelectOp::create(builder, isNegativeArg, negArg, operand); Value offset = offsets[0]; Value p[polyDegree + 1]; @@ -1091,30 +1097,30 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, Value isLessThanBound[intervalsCount]; for (int j = 0; j < intervalsCount - 1; ++j) { isLessThanBound[j] = - builder.create(arith::CmpFPredicate::OLT, x, bounds[j]); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, bounds[j]); for (int i = 0; i <= polyDegree; ++i) { - p[i] = builder.create(isLessThanBound[j], p[i], - pp[j + 1][i]); - q[i] = builder.create(isLessThanBound[j], q[i], - qq[j + 1][i]); + p[i] = arith::SelectOp::create(builder, isLessThanBound[j], p[i], + pp[j + 1][i]); + q[i] = arith::SelectOp::create(builder, isLessThanBound[j], q[i], + qq[j + 1][i]); } - offset = builder.create(isLessThanBound[j], offset, - offsets[j + 1]); + offset = arith::SelectOp::create(builder, isLessThanBound[j], offset, + offsets[j + 1]); } - isLessThanBound[intervalsCount - 1] = builder.create( - arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]); + isLessThanBound[intervalsCount - 1] = arith::CmpFOp::create( + builder, arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]); Value pPoly = makePolynomialCalculation(builder, p, x); Value qPoly = makePolynomialCalculation(builder, q, x); - Value rationalPoly = builder.create(pPoly, qPoly); - Value formula = builder.create(offset, rationalPoly); - formula = builder.create(isLessThanBound[intervalsCount - 1], - formula, one); + Value rationalPoly = arith::DivFOp::create(builder, pPoly, qPoly); + Value formula = arith::AddFOp::create(builder, offset, rationalPoly); + formula = arith::SelectOp::create( + builder, isLessThanBound[intervalsCount - 1], formula, one); // erf is odd function: erf(x) = -erf(-x). - Value negFormula = builder.create(formula); + Value negFormula = arith::NegFOp::create(builder, formula); Value res = - builder.create(isNegativeArg, negFormula, formula); + arith::SelectOp::create(builder, isNegativeArg, negFormula, formula); rewriter.replaceOp(op, res); @@ -1155,65 +1161,67 @@ ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op, Value posInf = bcast(floatCst(builder, INFINITY, et)); Value clampVal = bcast(floatCst(builder, 10.0546875f, et)); - Value a = builder.create(x); - Value p = builder.create(a, pos2); - Value r = builder.create(one, p); - Value q = builder.create(neg4, r, one); - Value t = builder.create(builder.create(q, one), - neg2, a); - Value e = builder.create(builder.create(a), q, t); - q = builder.create(r, e, q); + Value a = math::AbsFOp::create(builder, x); + Value p = arith::AddFOp::create(builder, a, pos2); + Value r = arith::DivFOp::create(builder, one, p); + Value q = math::FmaOp::create(builder, neg4, r, one); + Value t = math::FmaOp::create(builder, arith::AddFOp::create(builder, q, one), + neg2, a); + Value e = + math::FmaOp::create(builder, arith::NegFOp::create(builder, a), q, t); + q = math::FmaOp::create(builder, r, e, q); p = bcast(floatCst(builder, -0x1.a4a000p-12f, et)); // -4.01139259e-4 Value c1 = bcast(floatCst(builder, -0x1.42a260p-10f, et)); // -1.23075210e-3 - p = builder.create(p, q, c1); + p = math::FmaOp::create(builder, p, q, c1); Value c2 = bcast(floatCst(builder, 0x1.585714p-10f, et)); // 1.31355342e-3 - p = builder.create(p, q, c2); + p = math::FmaOp::create(builder, p, q, c2); Value c3 = bcast(floatCst(builder, 0x1.1adcc4p-07f, et)); // 8.63227434e-3 - p = builder.create(p, q, c3); + p = math::FmaOp::create(builder, p, q, c3); Value c4 = bcast(floatCst(builder, -0x1.081b82p-07f, et)); // -8.05991981e-3 - p = builder.create(p, q, c4); + p = math::FmaOp::create(builder, p, q, c4); Value c5 = bcast(floatCst(builder, -0x1.bc0b6ap-05f, et)); // -5.42046614e-2 - p = builder.create(p, q, c5); + p = math::FmaOp::create(builder, p, q, c5); Value c6 = bcast(floatCst(builder, 0x1.4ffc46p-03f, et)); // 1.64055392e-1 - p = builder.create(p, q, c6); + p = math::FmaOp::create(builder, p, q, c6); Value c7 = bcast(floatCst(builder, -0x1.540840p-03f, et)); // -1.66031361e-1 - p = builder.create(p, q, c7); + p = math::FmaOp::create(builder, p, q, c7); Value c8 = bcast(floatCst(builder, -0x1.7bf616p-04f, et)); // -9.27639827e-2 - p = builder.create(p, q, c8); + p = math::FmaOp::create(builder, p, q, c8); Value c9 = bcast(floatCst(builder, 0x1.1ba03ap-02f, et)); // 2.76978403e-1 - p = builder.create(p, q, c9); - - Value d = builder.create(pos2, a, one); - r = builder.create(one, d); - q = builder.create(p, r, r); - Value negfa = builder.create(a); - Value fmaqah = builder.create(q, negfa, onehalf); - Value psubq = builder.create(p, q); - e = builder.create(fmaqah, pos2, psubq); - r = builder.create(e, r, q); - - Value s = builder.create(a, a); - e = builder.create(builder.create(s)); - - t = builder.create(builder.create(a), a, s); - r = builder.create( - r, e, - builder.create(builder.create(r, e), t)); - - Value isNotLessThanInf = builder.create( - builder.create(arith::CmpFPredicate::OLT, a, posInf), + p = math::FmaOp::create(builder, p, q, c9); + + Value d = math::FmaOp::create(builder, pos2, a, one); + r = arith::DivFOp::create(builder, one, d); + q = math::FmaOp::create(builder, p, r, r); + Value negfa = arith::NegFOp::create(builder, a); + Value fmaqah = math::FmaOp::create(builder, q, negfa, onehalf); + Value psubq = arith::SubFOp::create(builder, p, q); + e = math::FmaOp::create(builder, fmaqah, pos2, psubq); + r = math::FmaOp::create(builder, e, r, q); + + Value s = arith::MulFOp::create(builder, a, a); + e = math::ExpOp::create(builder, arith::NegFOp::create(builder, s)); + + t = math::FmaOp::create(builder, arith::NegFOp::create(builder, a), a, s); + r = math::FmaOp::create( + builder, r, e, + arith::MulFOp::create(builder, arith::MulFOp::create(builder, r, e), t)); + + Value isNotLessThanInf = arith::XOrIOp::create( + builder, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, a, posInf), trueValue); - r = builder.create(isNotLessThanInf, - builder.create(x, x), r); + r = arith::SelectOp::create(builder, isNotLessThanInf, + arith::AddFOp::create(builder, x, x), r); Value isGreaterThanClamp = - builder.create(arith::CmpFPredicate::OGT, a, clampVal); - r = builder.create(isGreaterThanClamp, zero, r); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, a, clampVal); + r = arith::SelectOp::create(builder, isGreaterThanClamp, zero, r); Value isNegative = - builder.create(arith::CmpFPredicate::OLT, x, zero); - r = builder.create( - isNegative, builder.create(pos2, r), r); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, zero); + r = arith::SelectOp::create(builder, isNegative, + arith::SubFOp::create(builder, pos2, r), r); rewriter.replaceOp(op, r); return success(); @@ -1235,8 +1243,9 @@ Value clampWithNormals(ImplicitLocOpBuilder &builder, }; auto selectCmp = [&builder](auto pred, Value value, Value bound) { - return builder.create( - builder.create(pred, value, bound), value, bound); + return arith::SelectOp::create( + builder, arith::CmpFOp::create(builder, pred, value, bound), value, + bound); }; // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs. @@ -1268,17 +1277,17 @@ ExpApproximation::matchAndRewrite(math::ExpOp op, ImplicitLocOpBuilder builder(op->getLoc(), rewriter); auto add = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return arith::AddFOp::create(builder, a, b); }; auto bcast = [&](Value value) -> Value { return broadcast(builder, value, shape); }; - auto floor = [&](Value a) { return builder.create(a); }; + auto floor = [&](Value a) { return math::FloorOp::create(builder, a); }; auto fmla = [&](Value a, Value b, Value c) { - return builder.create(a, b, c); + return math::FmaOp::create(builder, a, b, c); }; auto mul = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return arith::MulFOp::create(builder, a, b); }; // Polynomial approximation from Cephes. @@ -1382,7 +1391,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op, // Convert n' to an i32. This is safe because we clamped it above. auto i32Vec = broadcast(builder.getI32Type(), shape); - Value nI32 = builder.create(i32Vec, n); + Value nI32 = arith::FPToSIOp::create(builder, i32Vec, n); // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127. Value pow2 = exp2I32(builder, nI32); @@ -1430,26 +1439,26 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op, Value cstOne = bcast(f32Cst(builder, 1.0f)); Value cstNegOne = bcast(f32Cst(builder, -1.0f)); Value x = op.getOperand(); - Value u = builder.create(x); + Value u = math::ExpOp::create(builder, x); Value uEqOneOrNaN = - builder.create(arith::CmpFPredicate::UEQ, u, cstOne); - Value uMinusOne = builder.create(u, cstOne); - Value uMinusOneEqNegOne = builder.create( - arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne); + arith::CmpFOp::create(builder, arith::CmpFPredicate::UEQ, u, cstOne); + Value uMinusOne = arith::SubFOp::create(builder, u, cstOne); + Value uMinusOneEqNegOne = arith::CmpFOp::create( + builder, arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne); // logU = log(u) ~= x - Value logU = builder.create(u); + Value logU = math::LogOp::create(builder, u); // Detect exp(x) = +inf; written this way to avoid having to form +inf. Value isInf = - builder.create(arith::CmpFPredicate::OEQ, logU, u); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, logU, u); // (u - 1) * (x / ~x) - Value expm1 = builder.create( - uMinusOne, builder.create(x, logU)); - expm1 = builder.create(isInf, u, expm1); - Value approximation = builder.create( - uEqOneOrNaN, x, - builder.create(uMinusOneEqNegOne, cstNegOne, expm1)); + Value expm1 = arith::MulFOp::create(builder, uMinusOne, + arith::DivFOp::create(builder, x, logU)); + expm1 = arith::SelectOp::create(builder, isInf, u, expm1); + Value approximation = arith::SelectOp::create( + builder, uEqOneOrNaN, x, + arith::SelectOp::create(builder, uMinusOneEqNegOne, cstNegOne, expm1)); rewriter.replaceOp(op, approximation); return success(); } @@ -1494,40 +1503,40 @@ LogicalResult SinAndCosApproximation::matchAndRewrite( return broadcast(builder, value, shape); }; auto mul = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return arith::MulFOp::create(builder, a, b); }; auto sub = [&](Value a, Value b) -> Value { - return builder.create(a, b); + return arith::SubFOp::create(builder, a, b); }; - auto floor = [&](Value a) { return builder.create(a); }; + auto floor = [&](Value a) { return math::FloorOp::create(builder, a); }; auto i32Vec = broadcast(builder.getI32Type(), shape); auto fPToSingedInteger = [&](Value a) -> Value { - return builder.create(i32Vec, a); + return arith::FPToSIOp::create(builder, i32Vec, a); }; auto modulo4 = [&](Value a) -> Value { - return builder.create(a, bcast(i32Cst(builder, 3))); + return arith::AndIOp::create(builder, a, bcast(i32Cst(builder, 3))); }; auto isEqualTo = [&](Value a, Value b) -> Value { - return builder.create(arith::CmpIPredicate::eq, a, b); + return arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, a, b); }; auto isGreaterThan = [&](Value a, Value b) -> Value { - return builder.create(arith::CmpIPredicate::sgt, a, b); + return arith::CmpIOp::create(builder, arith::CmpIPredicate::sgt, a, b); }; auto select = [&](Value cond, Value t, Value f) -> Value { - return builder.create(cond, t, f); + return arith::SelectOp::create(builder, cond, t, f); }; auto fmla = [&](Value a, Value b, Value c) { - return builder.create(a, b, c); + return math::FmaOp::create(builder, a, b, c); }; auto bitwiseOr = [&](Value a, Value b) { - return builder.create(a, b); + return arith::OrIOp::create(builder, a, b); }; Value twoOverPi = bcast(f32Cst(builder, (float)TWO_OVER_PI)); @@ -1624,7 +1633,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op, intTy = broadcast(intTy, shape); auto bconst = [&](TypedAttr attr) -> Value { - Value value = b.create(attr); + Value value = arith::ConstantOp::create(b, attr); return broadcast(b, value, shape); }; @@ -1641,44 +1650,44 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op, // union {int ix; float x;}; // x = x0; // ix = ix/4 + ix/16; - Value absValue = b.create(operand); - Value intValue = b.create(intTy, absValue); - Value divideBy4 = b.create(intValue, intTwo); - Value divideBy16 = b.create(intValue, intFour); - intValue = b.create(divideBy4, divideBy16); + Value absValue = math::AbsFOp::create(b, operand); + Value intValue = arith::BitcastOp::create(b, intTy, absValue); + Value divideBy4 = arith::ShRSIOp::create(b, intValue, intTwo); + Value divideBy16 = arith::ShRSIOp::create(b, intValue, intFour); + intValue = arith::AddIOp::create(b, divideBy4, divideBy16); // ix = ix + ix/16; - divideBy16 = b.create(intValue, intFour); - intValue = b.create(intValue, divideBy16); + divideBy16 = arith::ShRSIOp::create(b, intValue, intFour); + intValue = arith::AddIOp::create(b, intValue, divideBy16); // ix = ix + ix/256; - Value divideBy256 = b.create(intValue, intEight); - intValue = b.create(intValue, divideBy256); + Value divideBy256 = arith::ShRSIOp::create(b, intValue, intEight); + intValue = arith::AddIOp::create(b, intValue, divideBy256); // ix = 0x2a5137a0 + ix; - intValue = b.create(intValue, intMagic); + intValue = arith::AddIOp::create(b, intValue, intMagic); // Perform one newtons step: // x = 0.33333333f*(2.0f*x + x0/(x*x)); - Value floatValue = b.create(floatTy, intValue); - Value squared = b.create(floatValue, floatValue); - Value mulTwo = b.create(floatValue, fpTwo); - Value divSquared = b.create(absValue, squared); - floatValue = b.create(mulTwo, divSquared); - floatValue = b.create(floatValue, fpThird); + Value floatValue = arith::BitcastOp::create(b, floatTy, intValue); + Value squared = arith::MulFOp::create(b, floatValue, floatValue); + Value mulTwo = arith::MulFOp::create(b, floatValue, fpTwo); + Value divSquared = arith::DivFOp::create(b, absValue, squared); + floatValue = arith::AddFOp::create(b, mulTwo, divSquared); + floatValue = arith::MulFOp::create(b, floatValue, fpThird); // x = 0.33333333f*(2.0f*x + x0/(x*x)); - squared = b.create(floatValue, floatValue); - mulTwo = b.create(floatValue, fpTwo); - divSquared = b.create(absValue, squared); - floatValue = b.create(mulTwo, divSquared); - floatValue = b.create(floatValue, fpThird); + squared = arith::MulFOp::create(b, floatValue, floatValue); + mulTwo = arith::MulFOp::create(b, floatValue, fpTwo); + divSquared = arith::DivFOp::create(b, absValue, squared); + floatValue = arith::AddFOp::create(b, mulTwo, divSquared); + floatValue = arith::MulFOp::create(b, floatValue, fpThird); // Check for zero and restore sign. Value isZero = - b.create(arith::CmpFPredicate::OEQ, absValue, fpZero); - floatValue = b.create(isZero, fpZero, floatValue); - floatValue = b.create(floatValue, operand); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absValue, fpZero); + floatValue = arith::SelectOp::create(b, isZero, fpZero, floatValue); + floatValue = math::CopySignOp::create(b, floatValue, operand); rewriter.replaceOp(op, floatValue); return success(); @@ -1719,29 +1728,29 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); - Value negHalf = builder.create(op.getOperand(), cstNegHalf); + Value negHalf = arith::MulFOp::create(builder, op.getOperand(), cstNegHalf); // Select only the inverse sqrt of positive normals (denormals are // flushed to zero). - Value ltMinMask = builder.create( - arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos); - Value infMask = builder.create(arith::CmpFPredicate::OEQ, - op.getOperand(), cstPosInf); - Value notNormalFiniteMask = builder.create(ltMinMask, infMask); + Value ltMinMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, + op.getOperand(), cstMinNormPos); + Value infMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, + op.getOperand(), cstPosInf); + Value notNormalFiniteMask = arith::OrIOp::create(builder, ltMinMask, infMask); // Compute an approximate result. Value yApprox = handleMultidimensionalVectors( builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value { - return builder.create(operands); + return x86vector::RsqrtOp::create(builder, operands); }); // Do a single step of Newton-Raphson iteration to improve the approximation. // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). // It is essential to evaluate the inner term like this because forming // y_n^2 may over- or underflow. - Value inner = builder.create(negHalf, yApprox); - Value fma = builder.create(yApprox, inner, cstOnePointFive); - Value yNewton = builder.create(yApprox, fma); + Value inner = arith::MulFOp::create(builder, negHalf, yApprox); + Value fma = math::FmaOp::create(builder, yApprox, inner, cstOnePointFive); + Value yNewton = arith::MulFOp::create(builder, yApprox, fma); // Select the result of the Newton-Raphson step for positive normal arguments. // For other arguments, choose the output of the intrinsic. This will @@ -1749,7 +1758,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, // x is zero or a positive denormalized float (equivalent to flushing positive // denormalized inputs to zero). Value res = - builder.create(notNormalFiniteMask, yApprox, yNewton); + arith::SelectOp::create(builder, notNormalFiniteMask, yApprox, yNewton); rewriter.replaceOp(op, res); return success(); diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index cf506d1e7812b..ed0df4e8c5812 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -91,7 +91,7 @@ SmallVector mlir::mesh::getMixedAsValues(OpBuilder b, values.emplace_back(*(dyn++)); } else { TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s); - values.emplace_back(b.create(loc, type, val)); + values.emplace_back(arith::ConstantOp::create(b, loc, type, val)); } } return values; @@ -316,10 +316,10 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding, if (!newShardOp) { auto shardingOp = - builder.create(operandValue.getLoc(), sharding); - newShardOp = - builder.create(operandValue.getLoc(), operandValue, shardingOp, - /*annotate_for_users*/ false); + ShardingOp::create(builder, operandValue.getLoc(), sharding); + newShardOp = ShardOp::create(builder, operandValue.getLoc(), operandValue, + shardingOp, + /*annotate_for_users*/ false); } operandValue.replaceUsesWithIf( newShardOp, [operandOp, operandValue](OpOperand &use) { @@ -330,9 +330,9 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding, return; } - auto newShardOp2 = builder.create(operandValue.getLoc(), newShardOp, - newShardOp.getSharding(), - /*annotate_for_users*/ true); + auto newShardOp2 = ShardOp::create(builder, operandValue.getLoc(), newShardOp, + newShardOp.getSharding(), + /*annotate_for_users*/ true); newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2); } @@ -378,10 +378,10 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding, builder.setInsertionPoint(operandOp); auto shardingOp = - builder.create(operand.get().getLoc(), sharding); + ShardingOp::create(builder, operand.get().getLoc(), sharding); auto newShardOp = - builder.create(operandValue.getLoc(), operandValue, shardingOp, - /*annotate_for_users*/ true); + ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp, + /*annotate_for_users*/ true); IRRewriter rewriter(builder); rewriter.replaceUsesWithIf( operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) { @@ -395,8 +395,8 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding, builder.setInsertionPoint(newShardOp); auto newPreceedingShardOp = - builder.create(operandValue.getLoc(), operandValue, shardingOp, - /*annotate_for_users*/ false); + ShardOp::create(builder, operandValue.getLoc(), operandValue, shardingOp, + /*annotate_for_users*/ false); rewriter.replaceUsesWithIf( newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) { return use.getOwner() == newShardOp.getOperation(); diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp index 9da3c9a3dd160..db5fd6e494da1 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp @@ -91,15 +91,15 @@ struct MeshShapeFolder newShapeOpMeshAxes.push_back(opMeshAxes[i]); } else { // Fold static mesh axes. - newResults[i] = builder.create( - builder.getIndexAttr(meshAxisSize)); + newResults[i] = arith::ConstantOp::create( + builder, builder.getIndexAttr(meshAxisSize)); } } // Leave only the dynamic mesh axes to be queried. if (!newShapeOpMeshAxes.empty()) { MeshShapeOp newShapeOp = - builder.create(mesh.getSymName(), newShapeOpMeshAxes); + MeshShapeOp::create(builder, mesh.getSymName(), newShapeOpMeshAxes); for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) { newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i]; } diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index d7b7234f69347..1e54affa8198f 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -265,7 +265,8 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis); ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis( sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis); - Value allGatherResult = builder.create( + Value allGatherResult = AllGatherOp::create( + builder, RankedTensorType::get(allGatherResultShape.getShape(), allGatherResultShape.getElementType()), mesh.getSymName(), SmallVector({splitMeshAxis}), sourceShard, @@ -273,7 +274,8 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ShapedType targetShape = shardShapedType(sourceUnshardedShape, mesh, targetSharding); TypedValue targetShard = cast>( - builder.create(targetShape, allGatherResult).getResult()); + tensor::CastOp::create(builder, targetShape, allGatherResult) + .getResult()); return {targetShard, targetSharding}; } @@ -398,7 +400,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis( sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis, targetTensorAxis); - Value allToAllResult = builder.create( + Value allToAllResult = AllToAllOp::create( + builder, RankedTensorType::get(allToAllResultShape.getShape(), allToAllResultShape.getElementType()), mesh.getSymName(), SmallVector({meshAxis}), sourceShard, @@ -406,7 +409,7 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, ShapedType targetShape = shardShapedType(sourceUnshardedShape, mesh, targetSharding); TypedValue targetShard = cast>( - builder.create(targetShape, allToAllResult).getResult()); + tensor::CastOp::create(builder, targetShape, allToAllResult).getResult()); return {targetShard, targetSharding}; } @@ -477,15 +480,16 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, // Extract core from source and copy into destination core. auto noVals = ValueRange{}; - auto initVal = builder.create( - sourceShard.getLoc(), outShape, sourceShard.getType().getElementType()); - auto core = builder.create( - sourceShard.getLoc(), + auto initVal = + tensor::EmptyOp::create(builder, sourceShard.getLoc(), outShape, + sourceShard.getType().getElementType()); + auto core = tensor::ExtractSliceOp::create( + builder, sourceShard.getLoc(), RankedTensorType::get(coreShape, sourceShard.getType().getElementType()), sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides); - auto initOprnd = builder.create( - sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs, - coreShape, strides); + auto initOprnd = tensor::InsertSliceOp::create( + builder, sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, + tgtCoreOffs, coreShape, strides); // Finally update the halo. auto updateHaloResult = diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp index f08ef75d8a004..6ae95ae1f8a49 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp @@ -49,10 +49,11 @@ struct ProcessMultiIndexOpLowering ImplicitLocOpBuilder builder(op->getLoc(), rewriter); builder.setInsertionPointAfter(op.getOperation()); - Value linearIndex = builder.create(mesh); - ValueRange meshShape = builder.create(mesh).getResults(); + Value linearIndex = ProcessLinearIndexOp::create(builder, mesh); + ValueRange meshShape = MeshShapeOp::create(builder, mesh).getResults(); SmallVector completeMultiIndex = - builder.create(linearIndex, meshShape) + affine::AffineDelinearizeIndexOp::create(builder, linearIndex, + meshShape) .getMultiIndex(); SmallVector multiIndex; ArrayRef opMeshAxes = op.getAxes(); @@ -101,32 +102,33 @@ struct AllSliceOpLowering ImplicitLocOpBuilder builder(op->getLoc(), rewriter); builder.setInsertionPointAfter(op.getOperation()); - Value zero = builder.create(builder.getIndexAttr(0)); + Value zero = arith::ConstantOp::create(builder, builder.getIndexAttr(0)); Operation::result_range processInGroupMultiIndex = - builder.create(mesh.getSymName(), op.getMeshAxes()) + ProcessMultiIndexOp::create(builder, mesh.getSymName(), + op.getMeshAxes()) .getResults(); Operation::result_range processGroupShape = - builder.create(mesh.getSymName(), op.getMeshAxes()) + MeshShapeOp::create(builder, mesh.getSymName(), op.getMeshAxes()) .getResult(); Value processGroupSize = createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder); int64_t sliceAxis = op.getSliceAxis().getSExtValue(); Value operandSliceAxisSize = - builder.create(op.getOperand(), sliceAxis); + tensor::DimOp::create(builder, op.getOperand(), sliceAxis); Value operandSliceAxisSizeModProcessGroupSize = - builder.create(operandSliceAxisSize, processGroupSize); - Value isTargetShapeExactlyDivisible = builder.create( - arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize, - zero); - builder.create(isTargetShapeExactlyDivisible, - "Slicing a tensor with axis size that is " - "not exactly divisible by the " - "mesh process group size is not supported."); + arith::RemUIOp::create(builder, operandSliceAxisSize, processGroupSize); + Value isTargetShapeExactlyDivisible = + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, + operandSliceAxisSizeModProcessGroupSize, zero); + cf::AssertOp::create(builder, isTargetShapeExactlyDivisible, + "Slicing a tensor with axis size that is " + "not exactly divisible by the " + "mesh process group size is not supported."); Value resultSliceAxisSize = - builder.create(operandSliceAxisSize, processGroupSize); + arith::DivUIOp::create(builder, operandSliceAxisSize, processGroupSize); OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( llvm::to_vector_of(processInGroupMultiIndex), llvm::to_vector_of(processGroupShape), builder); @@ -139,7 +141,7 @@ struct AllSliceOpLowering if (i == sliceAxis) { sizes.emplace_back(resultSliceAxisSize); } else { - Value dimSize = builder.create(op.getOperand(), i); + Value dimSize = tensor::DimOp::create(builder, op.getOperand(), i); sizes.emplace_back(dimSize); } } @@ -152,10 +154,10 @@ struct AllSliceOpLowering resultSliceAxisSize); SmallVector strides( operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1)); - Value slice = builder.create( - op.getOperand(), offsets, sizes, strides); + Value slice = tensor::ExtractSliceOp::create(builder, op.getOperand(), + offsets, sizes, strides); Value newResult = - builder.create(op.getResult().getType(), slice); + tensor::CastOp::create(builder, op.getResult().getType(), slice); rewriter.replaceAllUsesWith(op.getResult(), newResult); return success(); @@ -201,7 +203,7 @@ TypedValue createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef axes, ImplicitLocOpBuilder &builder) { Operation::result_range meshShape = - builder.create(mesh, axes).getResults(); + mesh::MeshShapeOp::create(builder, mesh, axes).getResults(); return cast>(arith::createProduct( builder, builder.getLoc(), llvm::to_vector_of(meshShape), builder.getIndexType())); @@ -212,13 +214,14 @@ createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex, ArrayRef meshAxes, ImplicitLocOpBuilder &builder) { Operation::result_range processGroupShape = - builder.create(mesh, meshAxes).getResult(); + MeshShapeOp::create(builder, mesh, meshAxes).getResult(); OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( llvm::to_vector_of(processInGroupMultiIndex), llvm::to_vector_of(processGroupShape), builder); auto res = dyn_cast(processInGroupLinearIndex); if (!res) - res = builder.create( + res = arith::ConstantIndexOp::create( + builder, cast(cast(processInGroupLinearIndex)).getInt()); return cast>(res); } @@ -227,7 +230,7 @@ TypedValue createProcessLinearIndex(StringRef mesh, ArrayRef meshAxes, ImplicitLocOpBuilder &builder) { return createProcessLinearIndex( - mesh, builder.create(mesh, meshAxes).getResults(), + mesh, ProcessMultiIndexOp::create(builder, mesh, meshAxes).getResults(), meshAxes, builder); } } // namespace mlir::mesh diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index d2c94b124cdfb..5d253c1199dc0 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -333,15 +333,15 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter, // srcElement = (pred) ? prevSrcElements : 0; // Location loc = asyncCopyOp->getLoc(); - Value dstElements = - rewriter.create(loc, asyncCopyOp.getDstElementsAttr()); + Value dstElements = arith::ConstantOp::create( + rewriter, loc, asyncCopyOp.getDstElementsAttr()); Value originalSrcElement = asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements; - Value c0Index = rewriter.create(loc, 0); - auto srcElements = rewriter.create( - loc, predicate, originalSrcElement, c0Index); - auto asyncCopyZeroFillOp = rewriter.create( - loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()), + Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto srcElements = arith::SelectOp::create(rewriter, loc, predicate, + originalSrcElement, c0Index); + auto asyncCopyZeroFillOp = nvgpu::DeviceAsyncCopyOp::create( + rewriter, loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()), asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(), asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements, UnitAttr()); @@ -675,7 +675,7 @@ MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc, for (auto indexing : indexings) { Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row())); Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col())); - auto load = b.create(loc, memref, ValueRange{row, col}); + auto load = memref::LoadOp::create(b, loc, memref, ValueRange{row, col}); res.push_back(load); } return res; @@ -688,7 +688,7 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand( Type elementType = getElementTypeOrSelf(memref.getType()); auto vt = VectorType::get(vectorShape, elementType); - Value res = b.create(loc, vt, loads[0]); + Value res = vector::SplatOp::create(b, loc, vt, loads[0]); foreachIndividualVectorElement( res, /*applyFn=*/ @@ -697,7 +697,7 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand( }, /*reduceFn=*/ [&](Value v, int64_t linearIdx, ArrayRef indices) { - res = b.create(loc, v, res, indices); + res = vector::InsertOp::create(b, loc, v, res, indices); }); return res; @@ -715,7 +715,7 @@ SmallVector MmaSyncBuilder::buildMemRefStores( Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row())); Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col())); Operation *store = - b.create(loc, val, memref, ValueRange{row, col}); + memref::StoreOp::create(b, loc, val, memref, ValueRange{row, col}); res.push_back(store); } return res; @@ -730,7 +730,7 @@ SmallVector MmaSyncBuilder::buildMmaSyncMemRefStoreOperand( vectorToStore, /*applyFn=*/ [&](Value v, int64_t linearIdx, ArrayRef indices) { - return b.create(loc, vectorToStore, indices); + return vector::ExtractOp::create(b, loc, vectorToStore, indices); }, /*reduceFn=*/ [&](Value v, int64_t linearIdx, ArrayRef indices) { @@ -810,8 +810,8 @@ FailureOr MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) { rhsIndexFn, rhsShape); Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef, resIndexFn, resShape); - res = b.create(loc, lhs, rhs, res, info.mmaShape, - info.tf32Enabled); + res = nvgpu::MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, + info.tf32Enabled); buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn, resShape); return res.getDefiningOp(); @@ -832,8 +832,8 @@ DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne( } Location loc = linalgOp.getLoc(); // TODO: more robust computation of laneId, for now assume a single warp. - Value laneId = rewriter.create( - loc, rewriter.getIndexType(), gpu::Dimension::x); + Value laneId = gpu::ThreadIdOp::create( + rewriter, loc, rewriter.getIndexType(), gpu::Dimension::x); if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp))) fail = false; } @@ -897,12 +897,12 @@ SmallVector HopperBuilder::buildPredicateLoadsOnThread0( ArrayRef> sharedMemBuffers, TypedValue barrier) { SmallVector loadOps; - Value zero = rewriter.create(loc, 0); - Value tidx = rewriter.create(loc, gpu::Dimension::x); - Value cond = - rewriter.create(loc, arith::CmpIPredicate::eq, tidx, zero); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x); + Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, + tidx, zero); // clang-format off - rewriter.create( + scf::IfOp::create(rewriter, /*location=*/loc, /*conditional=*/cond, /*thenBuilder=*/ @@ -917,14 +917,14 @@ SmallVector HopperBuilder::buildPredicateLoadsOnThread0( // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load. // This may or may not have perf implications. buildBarrierArriveTx(barrier, sizes); - rewriter.create(loc); + scf::YieldOp::create(rewriter, loc); }, /*elseBuilder=*/ [&](OpBuilder &lb, Location loc) { // TODO: is this for no-thread divergence? // Should we just yield the size and hoist? buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0)); - rewriter.create(loc); + scf::YieldOp::create(rewriter, loc); }); // clang-format on return loadOps; @@ -939,14 +939,15 @@ static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) { TypedValue HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) { auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter); - Value barrier = rewriter.create( - loc, + Value barrier = nvgpu::MBarrierCreateOp::create( + rewriter, loc, nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace)); - Value zero = rewriter.create(loc, 0); - rewriter.create( - loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), - zero, Value()); - rewriter.create(loc); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + nvgpu::MBarrierInitOp::create( + rewriter, loc, barrier, + getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero, + Value()); + gpu::BarrierOp::create(rewriter, loc); return cast>(barrier); } @@ -955,8 +956,8 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue memref, gpu::LaunchOp launchOp) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(launchOp); - Value unrankedMemRef = rewriter.create( - loc, + Value unrankedMemRef = memref::CastOp::create( + rewriter, loc, UnrankedMemRefType::get(memref.getType().getElementType(), memref.getType().getMemorySpace()), memref); @@ -966,8 +967,8 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue memref, getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes); auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter); - Value desc = rewriter.create( - loc, + Value desc = nvgpu::TmaCreateDescriptorOp::create( + rewriter, loc, nvgpu::TensorMapDescriptorType::get( rewriter.getContext(), MemRefType::Builder(memref.getType()) @@ -985,10 +986,10 @@ OpFoldResult HopperBuilder::buildTmaAsyncLoad( TypedValue barrier, SmallVectorImpl &loadOps) { MLIRContext *ctx = rewriter.getContext(); - Value zero = rewriter.create(loc, 0); - Operation *loadOp = rewriter.create( - loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero, - Value(), Value()); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Operation *loadOp = nvgpu::TmaAsyncLoadOp::create( + rewriter, loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, + zero, Value(), Value()); loadOps.push_back(loadOp); auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref); SmallVector symbols(mixedSizes.size()); @@ -1012,23 +1013,23 @@ void HopperBuilder::buildBarrierArriveTx( OpFoldResult size = affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes); Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size); - Value zero = rewriter.create(loc, 0); - rewriter.create(loc, barrier, sizeVal, zero, - Value()); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero, + Value()); } void HopperBuilder::buildTryWaitParity( TypedValue barrier) { Type i1 = rewriter.getI1Type(); - Value parity = rewriter.create(loc, i1, 0); + Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0); // 10M is an arbitrary, not too small or too big number to specify the number // of ticks before retry. // TODO: hoist this in a default dialect constant. Value ticksBeforeRetry = - rewriter.create(loc, 10000000); - Value zero = rewriter.create(loc, 0); - rewriter.create(loc, barrier, parity, - ticksBeforeRetry, zero); + arith::ConstantIndexOp::create(rewriter, loc, 10000000); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity, + ticksBeforeRetry, zero); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp index 47a0c7096de95..b392ffeb13de6 100644 --- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp +++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp @@ -109,17 +109,17 @@ static Value buildNumReadElements(OpBuilder &b, Location loc, for (auto [pos, sz] : llvm::zip(transferMask->extractPosition, transferMask->createMaskOp->getOperands())) { Value cmp = - b.create(loc, arith::CmpIPredicate::slt, - b.create(loc, pos), sz); + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, + arith::ConstantIndexOp::create(b, loc, pos), sz); if (!cond) { cond = cmp; continue; } - cond = b.create(loc, cmp, cond); + cond = arith::AndIOp::create(b, loc, cmp, cond); } - return b.create( - loc, cond, transferMask->createMaskOp->getOperands().back(), - b.create(loc, 0)); + return arith::SelectOp::create( + b, loc, cond, transferMask->createMaskOp->getOperands().back(), + arith::ConstantIndexOp::create(b, loc, 0)); } /// Return "true" if the conversion to async copy is supported by "async copy". @@ -251,8 +251,9 @@ void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op, int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * numElements) / 8; // bypass_l1 only possible with 16 byte transfer. - Value token = rewriter.create( - writeOp->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()), + Value token = nvgpu::DeviceAsyncCopyOp::create( + rewriter, writeOp->getLoc(), + nvgpu::DeviceAsyncTokenType::get(op->getContext()), /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp), /*src=*/loadBase, /*srcIndices=*/nvgpu::getIndices(readOp), @@ -264,11 +265,11 @@ void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op, } // Create the group and wait for it right after. - Value groupToken = rewriter.create( - op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()), - tokens); - rewriter.create(op->getLoc(), groupToken, - nullptr); + Value groupToken = nvgpu::DeviceAsyncCreateGroupOp::create( + rewriter, op->getLoc(), + nvgpu::DeviceAsyncTokenType::get(op->getContext()), tokens); + nvgpu::DeviceAsyncWaitOp::create(rewriter, op->getLoc(), groupToken, + nullptr); // Clean up old stores. for (Operation *writeOp : group) rewriter.eraseOp(writeOp); diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp index 44e7fa961da12..957b9632422a6 100644 --- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp +++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp @@ -74,27 +74,28 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc, int64_t mask = (1LL << (m - n)) - 1; if (permuteEveryN > 1) mask = mask << llvm::Log2_64(permuteEveryN); - Value srcBits = b.create(loc, mask); - srcBits = b.create(loc, src, srcBits); + Value srcBits = arith::ConstantIndexOp::create(b, loc, mask); + srcBits = arith::AndIOp::create(b, loc, src, srcBits); // Use the src bits to permute the target bits b[N:M] containing the // vector offset. if (permuteEveryN > 1) { int64_t shlBits = n - llvm::Log2_64(permuteEveryN); if (shlBits > 0) { - Value finalShiftVal = b.create(loc, shlBits); + Value finalShiftVal = arith::ConstantIndexOp::create(b, loc, shlBits); srcBits = b.createOrFold(loc, srcBits, finalShiftVal); } else if (shlBits < 0) { - Value finalShiftVal = b.create(loc, -1 * shlBits); + Value finalShiftVal = + arith::ConstantIndexOp::create(b, loc, -1 * shlBits); srcBits = b.createOrFold(loc, srcBits, finalShiftVal); } } else { - Value finalShiftVal = b.create(loc, n); + Value finalShiftVal = arith::ConstantIndexOp::create(b, loc, n); srcBits = b.createOrFold(loc, srcBits, finalShiftVal); } Value permutedVectorIdx = - b.create(loc, indices[tgtDim], srcBits); + arith::XOrIOp::create(b, loc, indices[tgtDim], srcBits); return permutedVectorIdx; } diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index 793db73575b4f..58cd160948f7f 100644 --- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -72,7 +72,7 @@ Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar, // Create tensor splat auto tensorConstant = - builder.create(loc, scalar, referenceShape); + tensor::SplatOp::create(builder, loc, scalar, referenceShape); return tensorConstant; } @@ -94,22 +94,22 @@ std::pair flattenUnrankedTensor(OpBuilder &builder, Location loc, // Get unranked input shape and total size auto *context = builder.getContext(); auto shapeType = shape::getExtentTensorType(context); - auto inputShape = builder.create(loc, shapeType, input); - Value inputSize = builder.create( - loc, builder.getIndexType(), inputShape); + auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input); + Value inputSize = shape::NumElementsOp::create( + builder, loc, builder.getIndexType(), inputShape); // Turn input size into 1D tensor auto flatShapeType = shape::getExtentTensorType(context, 1); auto flatInputShape = - builder.create(loc, flatShapeType, inputSize); + tensor::FromElementsOp::create(builder, loc, flatShapeType, inputSize); // Reshape input tensor into 1D auto inputType = cast(input.getType()); auto elementType = inputType.getElementType(); auto flatInputType = RankedTensorType::get({ShapedType::kDynamic}, elementType); - auto flatInput = builder.create(loc, flatInputType, input, - flatInputShape); + auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input, + flatInputShape); return std::make_pair(flatInput, inputShape); } @@ -142,39 +142,40 @@ flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input, auto *context = builder.getContext(); auto indexType = builder.getIndexType(); auto shapeType = shape::getExtentTensorType(context); - auto inputShape = builder.create(loc, shapeType, input); + auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input); // Get shape and sizes on left and right of axis - auto axisValue = builder.create(loc, axis); - auto axisNextValue = builder.create(loc, axis + 1); + auto axisValue = arith::ConstantIndexOp::create(builder, loc, axis); + auto axisNextValue = arith::ConstantIndexOp::create(builder, loc, axis + 1); auto shapeLeft = builder .create(loc, TypeRange{shapeType, shapeType}, inputShape, axisValue) .getResult(0); auto sizeLeft = - builder.create(loc, indexType, shapeLeft); + shape::NumElementsOp::create(builder, loc, indexType, shapeLeft); auto shapeRight = builder .create(loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue) .getResult(1); auto sizeRight = - builder.create(loc, indexType, shapeRight); + shape::NumElementsOp::create(builder, loc, indexType, shapeRight); // Compute flat input shape as a 3-element 1D tensor - auto axisSizeValue = builder.create(loc, axisSize); + auto axisSizeValue = arith::ConstantIndexOp::create(builder, loc, axisSize); auto flatShapeType = shape::getExtentTensorType(context, 3); - auto flatInputShape = builder.create( - loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight}); + auto flatInputShape = tensor::FromElementsOp::create( + builder, loc, flatShapeType, + ValueRange{sizeLeft, axisSizeValue, sizeRight}); // Reshape input to 3D tensor auto inputType = cast(input.getType()); auto elementType = inputType.getElementType(); auto flatInputType = RankedTensorType::get( {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType); - auto flatInput = builder.create(loc, flatInputType, input, - flatInputShape); + auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input, + flatInputShape); return std::make_pair(flatInput, inputShape); } @@ -192,8 +193,8 @@ Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input, auto inputType = cast(input.getType()); auto elementType = inputType.getElementType(); auto unrankedType = UnrankedTensorType::get(elementType); - return builder.create(loc, unrankedType, input, - inputShape); + return tensor::ReshapeOp::create(builder, loc, unrankedType, input, + inputShape); } // Create a tensor constant containing all scales in a per-channel quantized @@ -215,7 +216,7 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc, auto tensorType = RankedTensorType::get({(int64_t)scales.size()}, expressedType); auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs); - return builder.create(loc, tensorType, scalesAttr); + return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr); } // Create a tensor constant containing all zero points in a per-channel @@ -239,7 +240,7 @@ Value materializePerChannelZeroPoints( auto tensorType = RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType); auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs); - return builder.create(loc, tensorType, zeroPointsAttr); + return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr); } // Create a tensor constant containing all scales in a sub-channel quantized @@ -263,7 +264,7 @@ Value materializeSubChannelScales( auto tensorType = RankedTensorType::get(scales.getType().getShape(), expressedType); auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs); - return builder.create(loc, tensorType, scalesAttr); + return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr); } // Create a tensor constant containing all zero points in a sub-channel @@ -287,7 +288,7 @@ Value materializeSubChannelZeroPoints( auto tensorType = RankedTensorType::get(zeroPoints.getType().getShape(), storageType); auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs); - return builder.create(loc, tensorType, zeroPointsAttr); + return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr); } // Clamp the given scalar or tensor input using the storage bounds encoded in @@ -314,10 +315,10 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input, // Materialize bounds auto inputType = input.getType(); auto storageType = quantizedType.getStorageType(); - auto storageMinScalar = builder.create( - loc, storageType, quantizedType.getStorageTypeMin()); - auto storageMaxScalar = builder.create( - loc, storageType, quantizedType.getStorageTypeMax()); + auto storageMinScalar = arith::ConstantIntOp::create( + builder, loc, storageType, quantizedType.getStorageTypeMin()); + auto storageMaxScalar = arith::ConstantIntOp::create( + builder, loc, storageType, quantizedType.getStorageTypeMax()); auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar, inputType, inputShape); auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar, @@ -325,11 +326,11 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input, // Clamp if (quantizedType.isSigned()) { - input = builder.create(loc, input, storageMin); - input = builder.create(loc, input, storageMax); + input = arith::MaxSIOp::create(builder, loc, input, storageMin); + input = arith::MinSIOp::create(builder, loc, input, storageMax); } else { - input = builder.create(loc, input, storageMin); - input = builder.create(loc, input, storageMax); + input = arith::MaxUIOp::create(builder, loc, input, storageMin); + input = arith::MinUIOp::create(builder, loc, input, storageMax); } return input; } @@ -338,16 +339,16 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input, Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input, Type resultType, bool isSigned) { if (isSigned) - return builder.create(loc, resultType, input); - return builder.create(loc, resultType, input); + return arith::FPToSIOp::create(builder, loc, resultType, input); + return arith::FPToUIOp::create(builder, loc, resultType, input); } // Emit op 'arith.sitofp' or 'arith.uitofp'. Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input, Type resultType, bool isSigned) { if (isSigned) - return builder.create(loc, resultType, input); - return builder.create(loc, resultType, input); + return arith::SIToFPOp::create(builder, loc, resultType, input); + return arith::UIToFPOp::create(builder, loc, resultType, input); } // Quantize a scalar or ranked tensor value. The stored value is clamped using @@ -362,7 +363,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input, scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape); // Scale input - auto scaledValue = builder.create(loc, input, scale); + auto scaledValue = arith::DivFOp::create(builder, loc, input, scale); // Skip unnecessary computations if no zero point is given Value storedValueFloat = scaledValue; @@ -377,7 +378,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input, // Add zero point to stored value storedValueFloat = - builder.create(loc, scaledValue, zeroPoint); + arith::AddFOp::create(builder, loc, scaledValue, zeroPoint); } // Convert stored value to storage type @@ -418,11 +419,11 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input, quantizedType.isSigned()); // Subtract zero point to stored value - result = builder.create(loc, result, zeroPoint); + result = arith::SubFOp::create(builder, loc, result, zeroPoint); } // Multiply by scale - result = builder.create(loc, result, scale); + result = arith::MulFOp::create(builder, loc, result, scale); return result; } @@ -477,11 +478,12 @@ Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op, auto storageType = quantizedType.getStorageType(); auto scaleAttr = builder.getFloatAttr(expressedType, quantizedType.getScale()); - auto scale = builder.create(loc, expressedType, scaleAttr); + auto scale = + arith::ConstantOp::create(builder, loc, expressedType, scaleAttr); auto zeroPointAttr = builder.getIntegerAttr(storageType, quantizedType.getZeroPoint()); auto zeroPoint = - builder.create(loc, storageType, zeroPointAttr); + arith::ConstantOp::create(builder, loc, storageType, zeroPointAttr); auto inputShape = getScalarOrTensorShape(builder, loc, input); return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint, @@ -546,7 +548,7 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, ? quantizedType.getStorageType() : quantizedType.getExpressedType(); auto initShape = tensor::getMixedSizes(builder, loc, input); - Value init = builder.create(loc, initShape, elementType); + Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType); SmallVector iteratorTypes(inputRank, utils::IteratorType::parallel); @@ -572,7 +574,7 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, convertRanked(builder, loc, op, input, {}, scale, zeroPoint, quantizedType); - builder.create(loc, result); + linalg::YieldOp::create(builder, loc, result); }) .getResult(0); @@ -642,7 +644,7 @@ Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op, ? quantizedType.getStorageType() : quantizedType.getExpressedType(); auto initShape = tensor::getMixedSizes(builder, loc, input); - Value init = builder.create(loc, initShape, elementType); + Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType); SmallVector iteratorTypes(inputRank, utils::IteratorType::parallel); @@ -675,7 +677,7 @@ Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op, convertRanked(builder, loc, op, input, {}, scale, zeroPoint, quantizedType); - builder.create(loc, result); + linalg::YieldOp::create(builder, loc, result); }) .getResult(0); @@ -729,8 +731,8 @@ struct DequantizeCastOpConversion // Convert quantized input to storage type auto storageScalarOrTensorType = getScalarOrTensorType(quantizedType.getStorageType(), input.getType()); - input = rewriter.create( - loc, storageScalarOrTensorType, input); + input = quant::StorageCastOp::create(rewriter, loc, + storageScalarOrTensorType, input); auto result = convertQuantized(rewriter, loc, op, input, quantizedType); diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp index 920b6ecb01d47..1ffb18fb7ab96 100644 --- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp @@ -41,8 +41,8 @@ class QuantizedTypeConverter : public TypeConverter { static Value materializeConversion(OpBuilder &builder, Type type, ValueRange inputs, Location loc) { - return builder.create(loc, type, - llvm::getSingleElement(inputs)); + return quant::StorageCastOp::create(builder, loc, type, + llvm::getSingleElement(inputs)); } public: