Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ using namespace mlir;
using namespace mlir::vector;

namespace {
/// Progressive lowering of BroadcastOp.

/// Convert a vector.broadcast with a vector operand to a lower rank
/// vector.broadcast. vector.broadcast with a scalar operand is expected to be
/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
using OpRewritePattern::OpRewritePattern;
Expand All @@ -40,20 +43,23 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
Type eltType = dstType.getElementType();

// Scalar to any vector can use splat.
if (!srcType) {
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
return success();
}
// A broadcast from a scalar is considered to be in the lowered form.
if (!srcType)
return rewriter.notifyMatchFailure(
op, "broadcast from scalar already in lowered form");

// Determine rank of source and destination.
int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();

// Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
// Here we are broadcasting to a rank-1 vector. Ensure that the source is a
// scalar.
if (srcRank <= 1 && dstRank == 1) {
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
SmallVector<int64_t> fullRankPosition(srcRank, 0);
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we use rewriter.create<...> here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't we moving to dialect::Op::create instead? #147168

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. It'll take some getting used to..

fullRankPosition);
assert(!isa<VectorType>(ext.getType()) && "expected scalar");
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext);
return success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering
read, "vector type is not rank 1, can't create masked load, needs "
"VectorToSCF");

Value fill = vector::SplatOp::create(
Value fill = vector::BroadcastOp::create(
rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
res = vector::MaskedLoadOp::create(
rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ class DecomposeNDExtractStridedSlice
// Extract/insert on a lower ranked extract strided slice op.
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
rewriter.getZeroAttr(elemType));
Value res = SplatOp::create(rewriter, loc, dstType, zero);
Value res = BroadcastOp::create(rewriter, loc, dstType, zero);
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
Value one = ExtractOp::create(rewriter, loc, op.getVector(), off);
Expand Down
61 changes: 39 additions & 22 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {

Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
rewriter.getZeroAttr(elemType));
Value res = SplatOp::create(rewriter, loc, castDstType, zero);
Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);

SmallVector<int64_t> sliceShape = {castDstLastDim};
SmallVector<int64_t> strides = {1};
Expand Down Expand Up @@ -987,6 +987,23 @@ static Type cloneOrReplace(Type type, Type newElementType) {
return newElementType;
}

/// If `value` is the result of a splat or broadcast operation, return the input
/// of the splat/broadcast operation.
static Value getBroadcastLikeSource(Value value) {

Operation *op = value.getDefiningOp();
if (!op)
return {};

if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
return broadcast.getSource();

if (auto splat = dyn_cast<vector::SplatOp>(op))
return splat.getInput();

return {};
}

/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
///
/// Example:
Expand Down Expand Up @@ -1026,39 +1043,37 @@ struct ReorderElementwiseOpsOnBroadcast final
}

Type resultElemType = resultType.getElementType();

// Get the type of the first non-constant operand
Operation *firstBroadcastOrSplat = nullptr;
Value splatSource;
for (Value operand : op->getOperands()) {
Operation *definingOp = operand.getDefiningOp();
if (!definingOp)
return failure();
if (definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
if (!isa<vector::BroadcastOp, vector::SplatOp>(*definingOp))
return failure();
firstBroadcastOrSplat = definingOp;
splatSource = getBroadcastLikeSource(operand);
break;
}
if (!firstBroadcastOrSplat)
if (!splatSource)
return failure();
Type unbroadcastResultType = cloneOrReplace(
firstBroadcastOrSplat->getOperand(0).getType(), resultElemType);
Type unbroadcastResultType =
cloneOrReplace(splatSource.getType(), resultElemType);

// Make sure that all operands are broadcast from identically-shaped types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
if (!llvm::all_of(op->getOperands(), [&unbroadcastResultType](Value val) {
if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
return haveSameShapeAndScaling(bcastOp.getOperand().getType(),
unbroadcastResultType);
if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
return haveSameShapeAndScaling(splatOp.getOperand().getType(),
unbroadcastResultType);
if (!llvm::all_of(op->getOperands(), [splatSource](Value val) {
if (auto source = getBroadcastLikeSource(val))
return haveSameShapeAndScaling(source.getType(),
splatSource.getType());
SplatElementsAttr splatConst;
return matchPattern(val, m_Constant(&splatConst));
})) {
return failure();
return rewriter.notifyMatchFailure(
op,
"not all operands are constants or broadcasts from the same type");
}

// Collect the source values before broadcasting
Expand Down Expand Up @@ -1287,15 +1302,17 @@ class StoreOpFromSplatOrBroadcast final
return rewriter.notifyMatchFailure(
op, "only 1-element vectors are supported");

Operation *splat = op.getValueToStore().getDefiningOp();
if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
Value toStore = op.getValueToStore();
Value source = getBroadcastLikeSource(toStore);
if (!source)
return rewriter.notifyMatchFailure(
op, "value to store is not from a broadcast");

// Checking for single use so we can remove splat.
Operation *splat = toStore.getDefiningOp();
if (!splat->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");

Value source = splat->getOperand(0);
Value base = op.getBase();
ValueRange indices = op.getIndices();

Expand Down Expand Up @@ -1345,13 +1362,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
// Add in an offset if requested.
if (off) {
Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o);
Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
indices = arith::AddIOp::create(rewriter, loc, ov, indices);
}
// Construct the vector comparison.
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
Value bounds =
vector::SplatOp::create(rewriter, loc, indices.getType(), bound);
vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
indices, bounds);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

// CHECK-LABEL: func @broadcast_vec1d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32>
// CHECK: return %[[T0]] : vector<2xf32>

func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
Expand All @@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {

// CHECK-LABEL: func @broadcast_vec2d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32>
// CHECK: return %[[T0]] : vector<2x3xf32>

func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
Expand All @@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {

// CHECK-LABEL: func @broadcast_vec3d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32>
// CHECK: return %[[T0]] : vector<2x3x4xf32>

func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
Expand Down Expand Up @@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
// CHECK-LABEL: func @broadcast_stretch
// CHECK-SAME: %[[A:.*0]]: vector<1xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32>
// CHECK: return %[[T1]] : vector<4xf32>

func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
Expand All @@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
// CHECK: %[[U0:.*]] = ub.poison : vector<4x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32>
// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK: %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32>
// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
// CHECK: %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32>
// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
// CHECK: %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32>
// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32>
// CHECK: %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32>
// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
// CHECK: return %[[T15]] : vector<4x3xf32>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK: return %[[T7]] : vector<2x3xf32>
Expand All @@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
Expand All @@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32>
// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK: return %[[T7]] : vector<2x3xi32>
Expand All @@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32>
// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
// CHECK: %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32>
// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32>
// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
Expand All @@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
// CHECK-LABEL: func @axpy_fp(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
Expand All @@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32,
// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
Expand All @@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>
// CHECK-LABEL: func @axpy_int(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: return %[[T1]] : vector<16xi32>
func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
Expand All @@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32,
// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
// CHECK: return %[[T2]] : vector<16xi32>
Expand Down