diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index cb8e566869cfd..dedc3b3f30201 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -28,7 +28,10 @@ using namespace mlir; using namespace mlir::vector; namespace { -/// Progressive lowering of BroadcastOp. + +/// Convert a vector.broadcast with a vector operand to a lower rank +/// vector.broadcast. vector.broadcast with a scalar operand is expected to be +/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly. class BroadcastOpLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -40,20 +43,23 @@ class BroadcastOpLowering : public OpRewritePattern { VectorType srcType = dyn_cast(op.getSourceType()); Type eltType = dstType.getElementType(); - // Scalar to any vector can use splat. - if (!srcType) { - rewriter.replaceOpWithNewOp(op, dstType, op.getSource()); - return success(); - } + // A broadcast from a scalar is considered to be in the lowered form. + if (!srcType) + return rewriter.notifyMatchFailure( + op, "broadcast from scalar already in lowered form"); // Determine rank of source and destination. int64_t srcRank = srcType.getRank(); int64_t dstRank = dstType.getRank(); - // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. + // Here we are broadcasting to a rank-1 vector. Ensure that the source is a + // scalar. if (srcRank <= 1 && dstRank == 1) { - Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource()); - rewriter.replaceOpWithNewOp(op, dstType, ext); + SmallVector fullRankPosition(srcRank, 0); + Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), + fullRankPosition); + assert(!isa(ext.getType()) && "expected scalar"); + rewriter.replaceOpWithNewOp(op, dstType, ext); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 4baeb1145d25b..2cf8f0beaa4de 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering read, "vector type is not rank 1, can't create masked load, needs " "VectorToSCF"); - Value fill = vector::SplatOp::create( + Value fill = vector::BroadcastOp::create( rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding()); res = vector::MaskedLoadOp::create( rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(), diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index 72352d72bfe77..cbb9d4bbf0b1f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -303,7 +303,7 @@ class DecomposeNDExtractStridedSlice // Extract/insert on a lower ranked extract strided slice op. Value zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = SplatOp::create(rewriter, loc, dstType, zero); + Value res = BroadcastOp::create(rewriter, loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { Value one = ExtractOp::create(rewriter, loc, op.getVector(), off); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 7500bf7d1d9eb..2269a40ec8ef1 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -939,7 +939,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern { Value zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = SplatOp::create(rewriter, loc, castDstType, zero); + Value res = BroadcastOp::create(rewriter, loc, castDstType, zero); SmallVector sliceShape = {castDstLastDim}; SmallVector strides = {1}; @@ -987,6 +987,23 @@ static Type cloneOrReplace(Type type, Type newElementType) { return newElementType; } +/// If `value` is the result of a splat or broadcast operation, return the input +/// of the splat/broadcast operation. +static Value getBroadcastLikeSource(Value value) { + + Operation *op = value.getDefiningOp(); + if (!op) + return {}; + + if (auto broadcast = dyn_cast(op)) + return broadcast.getSource(); + + if (auto splat = dyn_cast(op)) + return splat.getInput(); + + return {}; +} + /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex: /// /// Example: @@ -1026,39 +1043,37 @@ struct ReorderElementwiseOpsOnBroadcast final } Type resultElemType = resultType.getElementType(); + // Get the type of the first non-constant operand - Operation *firstBroadcastOrSplat = nullptr; + Value splatSource; for (Value operand : op->getOperands()) { Operation *definingOp = operand.getDefiningOp(); if (!definingOp) return failure(); if (definingOp->hasTrait()) continue; - if (!isa(*definingOp)) - return failure(); - firstBroadcastOrSplat = definingOp; + splatSource = getBroadcastLikeSource(operand); break; } - if (!firstBroadcastOrSplat) + if (!splatSource) return failure(); - Type unbroadcastResultType = cloneOrReplace( - firstBroadcastOrSplat->getOperand(0).getType(), resultElemType); + Type unbroadcastResultType = + cloneOrReplace(splatSource.getType(), resultElemType); // Make sure that all operands are broadcast from identically-shaped types: // * scalar (`vector.broadcast` + `vector.splat`), or // * vector (`vector.broadcast`). // Otherwise the re-ordering wouldn't be safe. - if (!llvm::all_of(op->getOperands(), [&unbroadcastResultType](Value val) { - if (auto bcastOp = val.getDefiningOp()) - return haveSameShapeAndScaling(bcastOp.getOperand().getType(), - unbroadcastResultType); - if (auto splatOp = val.getDefiningOp()) - return haveSameShapeAndScaling(splatOp.getOperand().getType(), - unbroadcastResultType); + if (!llvm::all_of(op->getOperands(), [splatSource](Value val) { + if (auto source = getBroadcastLikeSource(val)) + return haveSameShapeAndScaling(source.getType(), + splatSource.getType()); SplatElementsAttr splatConst; return matchPattern(val, m_Constant(&splatConst)); })) { - return failure(); + return rewriter.notifyMatchFailure( + op, + "not all operands are constants or broadcasts from the same type"); } // Collect the source values before broadcasting @@ -1287,15 +1302,17 @@ class StoreOpFromSplatOrBroadcast final return rewriter.notifyMatchFailure( op, "only 1-element vectors are supported"); - Operation *splat = op.getValueToStore().getDefiningOp(); - if (!isa_and_present(splat)) - return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast"); + Value toStore = op.getValueToStore(); + Value source = getBroadcastLikeSource(toStore); + if (!source) + return rewriter.notifyMatchFailure( + op, "value to store is not from a broadcast"); // Checking for single use so we can remove splat. + Operation *splat = toStore.getDefiningOp(); if (!splat->hasOneUse()) return rewriter.notifyMatchFailure(op, "expected single op use"); - Value source = splat->getOperand(0); Value base = op.getBase(); ValueRange indices = op.getIndices(); @@ -1345,13 +1362,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, // Add in an offset if requested. if (off) { Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); - Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o); + Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o); indices = arith::AddIOp::create(rewriter, loc, ov, indices); } // Construct the vector comparison. Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); Value bounds = - vector::SplatOp::create(rewriter, loc, indices.getType(), bound); + vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound); return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, indices, bounds); } diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir index 8e167a520260f..d5e344393b217 100644 --- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func @broadcast_vec1d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32> // CHECK: return %[[T0]] : vector<2xf32> func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { @@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { // CHECK-LABEL: func @broadcast_vec2d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32> // CHECK: return %[[T0]] : vector<2x3xf32> func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { @@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { // CHECK-LABEL: func @broadcast_vec3d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32> // CHECK: return %[[T0]] : vector<2x3x4xf32> func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { @@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3 // CHECK-LABEL: func @broadcast_stretch // CHECK-SAME: %[[A:.*0]]: vector<1xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32> // CHECK: return %[[T1]] : vector<4xf32> func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { @@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> // CHECK-SAME: %[[A:.*0]]: vector<4x1xf32> // CHECK: %[[U0:.*]] = ub.poison : vector<4x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32> +// CHECK: %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32> +// CHECK: %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32> // CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32> +// CHECK: %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32> // CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32> // CHECK: return %[[T15]] : vector<4x3xf32> diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir index 059d955f77313..5a8125ed67173 100644 --- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir @@ -5,11 +5,11 @@ // CHECK-SAME: %[[B:.*1]]: vector<3xf32> // CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32> -// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32> +// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32> // CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32> // CHECK: return %[[T7]] : vector<2x3xf32> @@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>, // CHECK-SAME: %[[C:.*2]]: vector<2x3xf32> // CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32> // CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32> // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> // CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32> +// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32> // CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32> // CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32> // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32> @@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>, // CHECK-SAME: %[[B:.*1]]: vector<3xi32> // CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32> // CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32> -// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32> +// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32> // CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32> // CHECK: return %[[T7]] : vector<2x3xi32> @@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>, // CHECK-SAME: %[[C:.*2]]: vector<2x3xi32> // CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32> // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32> // CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> // CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32> // CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> // CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32> -// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32> +// CHECK: %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32> // CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32> // CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32> // CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32> @@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>, // CHECK-LABEL: func @axpy_fp( // CHECK-SAME: %[[A:.*0]]: vector<16xf32>, // CHECK-SAME: %[[B:.*1]]: f32) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32> // CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32> // CHECK: return %[[T1]] : vector<16xf32> func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { @@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { // CHECK-SAME: %[[A:.*0]]: vector<16xf32>, // CHECK-SAME: %[[B:.*1]]: f32, // CHECK-SAME: %[[C:.*2]]: vector<16xf32>) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32> // CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32> // CHECK: return %[[T1]] : vector<16xf32> func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> { @@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32> // CHECK-LABEL: func @axpy_int( // CHECK-SAME: %[[A:.*0]]: vector<16xi32>, // CHECK-SAME: %[[B:.*1]]: i32) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32> // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> // CHECK: return %[[T1]] : vector<16xi32> func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { @@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { // CHECK-SAME: %[[A:.*0]]: vector<16xi32>, // CHECK-SAME: %[[B:.*1]]: i32, // CHECK-SAME: %[[C:.*2]]: vector<16xi32>) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32> // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> // CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32> // CHECK: return %[[T2]] : vector<16xi32>