From 14b732e98b09f42c017952cacf3a9f90e0f05369 Mon Sep 17 00:00:00 2001
From: "Lu,Chengjun"
Date: Thu, 3 Apr 2025 13:51:12 +0000
Subject: [PATCH] Lower the horizontal batched reduce of a matrix with
 optimized inline VISA.

Signed-off-by: Lu,Chengjun
---
 .../TritonGPUToLLVM/TargetInfoBase.h          |   6 +
 include/triton/Tools/Sys/GetEnv.hpp           |   1 +
 python/test/unit/intel/test_simd_reduce.py    | 150 ++++++++++++++++
 .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.h   |   7 +
 .../TritonIntelGPUToLLVM/XeAsmFormat.h        |  18 ++
 .../TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp   |  13 ++
 .../lib/TritonIntelGPUToLLVM/TargetInfo.cpp   | 164 ++++++++++++++++++
 .../lib/TritonIntelGPUToLLVM/TargetInfo.h     |   6 +
 .../lib/TritonIntelGPUToLLVM/XeAsmFormat.cpp  |  91 +++++++++-
 .../TritonIntelGPUToLLVM/XeAsmFormatTest.cpp  |  16 ++
 .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.h    |   7 +
 11 files changed, 478 insertions(+), 1 deletion(-)
 create mode 100644 python/test/unit/intel/test_simd_reduce.py
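For context, a hypothetical usage sketch (not part of this patch) of the kind of row-wise reduction this lowering targets: each program reduces one row of a matrix across the lanes of a sub-group. The kernel name, sizes, and the "xpu" device here are illustrative assumptions; whether the batched SIMD path is taken depends on the layouts the compiler picks, and the ttgir-driven test added below exercises it deterministically.

```python
# Hypothetical example (not from this patch): a horizontal (axis=1) reduce.
import os

os.environ["TRITON_INTEL_ENABLE_SIMD_REDUCE"] = "1"  # set before compilation

import torch
import triton
import triton.language as tl


@triton.jit
def row_sum_kernel(x_ptr, out_ptr, stride_m, N: tl.constexpr):
    m = tl.program_id(0)
    offs = tl.arange(0, N)
    row = tl.load(x_ptr + m * stride_m + offs)
    tl.store(out_ptr + m, tl.sum(row, axis=0))  # reduce one row over lanes


x = torch.randn(128, 64, device="xpu", dtype=torch.float16)
out = torch.empty(128, device="xpu", dtype=torch.float16)
row_sum_kernel[(128, )](x, out, x.stride(0), N=64)
```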
diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
index 1a8bff089d..00344b98a4 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -66,6 +66,12 @@ class TargetInfoBase {
                           unsigned numLaneToReduce,
                           unsigned interleave) const = 0;
 
+  virtual bool
+  warpBatchReduce(RewriterBase &rewriter, Location loc,
+                  std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
+                  triton::ReduceOp op, unsigned numLaneToReduce,
+                  unsigned interleave) const = 0;
+
   virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
   // Emits LLVM code with |rewriter| to print a message following the given
   // format from the device. |formatStrStart| is the pointer to the start of
diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp
index fa299fb36b..8154d95af1 100644
--- a/include/triton/Tools/Sys/GetEnv.hpp
+++ b/include/triton/Tools/Sys/GetEnv.hpp
@@ -54,6 +54,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_FAST_MATH",
     "TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT",
     "TRITON_INTEL_REDUCE_TRANSPOSE",
+    "TRITON_INTEL_ENABLE_SIMD_REDUCE",
     // clang-format on
 };
diff --git a/python/test/unit/intel/test_simd_reduce.py b/python/test/unit/intel/test_simd_reduce.py
new file mode 100644
index 0000000000..69dc237445
--- /dev/null
+++ b/python/test/unit/intel/test_simd_reduce.py
@@ -0,0 +1,150 @@
import os
import re
import numpy as np
from numpy.random import RandomState
import pytest
import torch
import pathlib

import triton
from triton._internal_testing import numpy_random, to_numpy

MIN_GROUP_SIZE = torch.xpu.get_device_capability()['sub_group_sizes'][0]

os.environ["TRITON_INTEL_ENABLE_SIMD_REDUCE"] = "1"


class DpasLayout:

    def __init__(self, repeatCount, systolic_depth, execution_size, ops_per_chan, threads_per_warp, warps_per_cta,
                 rep_cluster):
        self.repeatCount = repeatCount
        self.systolic_depth = systolic_depth
        self.execution_size = execution_size
        self.ops_per_chan = ops_per_chan
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.rep_cluster = rep_cluster

    def __str__(self):
        return f"#ttig.dpas<{{repeatCount={self.repeatCount}, systolicDepth={self.systolic_depth}, executionSize = {self.execution_size}, opsPerChan = {self.ops_per_chan}, threadsPerWarp = {self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, repCluster={self.rep_cluster}}}>"


class BlockedLayout:

    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1],
                 cta_split_num=[1, 1], cta_order=[0, 1]):
        self.sz_per_thread = size_per_thread
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.order = order
        self.ctas_per_cga = ctas_per_cga
        self.cta_split_num = cta_split_num
        self.cta_order = cta_order

    def __str__(self):
        return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"


layouts = [
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=16,
               warps_per_cta=[4, 1], rep_cluster=[1, 1]),
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16,
               warps_per_cta=[4, 1], rep_cluster=[1, 1]),
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16,
               warps_per_cta=[2, 2], rep_cluster=[1, 1]),
    BlockedLayout([1, 1], [1, 16], [2, 2], [1, 0])
]

if MIN_GROUP_SIZE == 16:
    # Add threads_per_warp=32 cases.
    layouts += [
        DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32,
                   warps_per_cta=[2, 2], rep_cluster=[1, 1]),
        BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0])
    ]


def warps_per_cta(layout, shape):
    return layout.warps_per_cta


GPU_DIALECT = "ttg"


@pytest.mark.parametrize("M, N", [[8, 64], [16, 64], [32, 64], [64, 64], [128, 64]])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("dtype_str", ["float32", "float16"])
@pytest.mark.parametrize("reduce_op", ["sum", "max"])
def test_horizontal_simd_reduce(M, N, src_layout, dtype_str, reduce_op, device, tmp_path: pathlib.Path):
    ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str]
    arith_op = {
        "max": {"int32": "arith.maxsi", "float32": "arith.maxnumf", "float16": "arith.maxnumf"},  #
        "sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"}
    }[reduce_op][dtype_str]
    numpy_op = {"max": np.max, "sum": np.sum}[reduce_op]
    rdims_1d = f"{M}"
    rdims_2d = f"{M}x1"
    store_range = "%1"
    warps = src_layout.warps_per_cta
    threads_per_warp = int(np.prod(src_layout.threads_per_warp))
    num_warps = int(np.prod(warps))
    blocked = BlockedLayout([1, 1], [16, threads_per_warp // 16], [4, num_warps // 4], [0, 1], [1, 1], [1, 1], [0, 1])
    one_d_layout = BlockedLayout([1], [threads_per_warp], [num_warps], [0], [1], [1], [0])

    ir = f"""
    #blocked = {blocked}
    #src = {src_layout}
    #one_d_layout = {one_d_layout}
    module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {threads_per_warp} : i32, "ttig.min_sg_size" = {MIN_GROUP_SIZE} }} {{
    tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
        %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>
        %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked>
        %2 = tt.splat %arg1 : i32 -> tensor<{M}x1xi32, #blocked>
        %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked>
        %4 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked>
        %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked>
        %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>
        %7 = tt.expand_dims %6 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked>
        %8 = tt.broadcast %5 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>
        %9 = tt.broadcast %7 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked>
        %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked>
        %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>
        %12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src>
        %13 = "tt.reduce"(%12) ({{
        ^bb0(%arg3: {ty}, %arg4: {ty}):
          %17 = {arith_op} %arg3, %arg4 : {ty}
          tt.reduce.return %17 : {ty}
        }}) {{axis = 1 : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
        %14 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>
        %15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
        %16 = {GPU_DIALECT}.convert_layout %13 : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>
        %17 = tt.expand_dims %16 {{axis = 1 : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{rdims_2d}x{ty}, #blocked>
        tt.store %15, %17 : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>
        tt.return
    }}
    }}
    """

    temp_file = tmp_path / "test_reduce_layouts.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))

    rs = RandomState(17)
    x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10)
    z_shape = (M, 1)
    z = np.zeros(z_shape).astype(dtype_str)

    x_tri = torch.tensor(x, device=device)
    z_tri = torch.tensor(z, device=device)

    kernel[(1, 1, 1)](x_tri, x_tri.stride(0), z_tri)
    z_ref = numpy_op(x, axis=1, keepdims=True)

    if dtype_str == 'float16':
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
    else:
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)

    llir = kernel.asm['llir']
    assert re.search(r'call .* asm', llir), 'no inline visa in llir'  # inline visa is used
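To make the data movement concrete, here is an illustrative numpy model (not part of the patch) of what the batched horizontal reduce computes: a warp holds an accSize x warpSize tile, one column per lane, and each row is reduced over groups of numLaneToReduce adjacent lanes (interleave == 1), yielding warpSize // numLaneToReduce results per row.

```python
import numpy as np


def batched_horizontal_reduce(tile, num_lane_to_reduce, op=np.add):
    # tile: (accSize, warpSize) array, one column per lane.
    acc_size, warp_size = tile.shape
    groups = tile.reshape(acc_size, warp_size // num_lane_to_reduce,
                          num_lane_to_reduce)
    return op.reduce(groups, axis=2)  # shape (accSize, numResultPerRow)


tile = np.arange(4 * 16, dtype=np.float32).reshape(4, 16)
print(batched_horizontal_reduce(tile, 16))             # full warp: 4x1 results
print(batched_horizontal_reduce(tile, 8, np.maximum))  # clustered: 4x2 results
```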
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h
index 9777ce0ec7..7d19c3bbac 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h
@@ -56,6 +56,13 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
                   triton::ReduceOp op, unsigned numLaneToReduce,
                   unsigned interleave) const override;
 
+  bool warpBatchReduce(RewriterBase &rewriter, Location loc,
+                       std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
+                       triton::ReduceOp op, unsigned numLaneToReduce,
+                       unsigned interleave) const override {
+    return false;
+  }
+
   std::string getMulhiFuncName(Type resultElementTy) const override;
 
   void printf(RewriterBase &rewriter, Value formatStrStart,
diff --git a/third_party/intel/include/TritonIntelGPUToLLVM/XeAsmFormat.h b/third_party/intel/include/TritonIntelGPUToLLVM/XeAsmFormat.h
index 6e89809caa..69904090f8 100644
--- a/third_party/intel/include/TritonIntelGPUToLLVM/XeAsmFormat.h
+++ b/third_party/intel/include/TritonIntelGPUToLLVM/XeAsmFormat.h
@@ -315,6 +315,24 @@ struct XeInstrExecution {
   bool onlyAttachMLIRArgs{};
 };
 
+enum XeArch {
+  Xe = 0,
+  Xe2 = 1,
+  Xe3 = 2,
+};
+
+struct XeVISAInstr : public XeInstrBase<XeVISAInstr> {
+  using XeInstrBase<XeVISAInstr>::XeInstrBase;
+
+  // Maps an MLIR scalar type to its vISA type syntax (e.g. f32 -> "f").
+  static std::optional<std::string> getTypeName(Type scalarTy);
+  static unsigned getGRFSizeInBytes(XeArch arch);
+  static unsigned getExecMaskLaneNum(XeArch arch);
+};
+
+std::string simdReduceAsm(std::string binOp, unsigned warpSize,
+                          unsigned numLaneToReduce, unsigned accSize,
+                          Type elemTy, XeArch arch);
+
 } // namespace triton
 } // namespace mlir
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp
index 08a8b71e46..032df033a2 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp
@@ -1,6 +1,7 @@
 #include "PatternTritonGPUOpToLLVM.h"
 #include "lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include <triton/Tools/Sys/GetEnv.hpp>
 
 using namespace mlir;
 using namespace mlir::triton;
@@ -176,6 +177,18 @@ struct ReduceOpConversion
     unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
     unsigned threadOffsetOnReductionAxis =
         helper.getThreadOffsetOnReductionAxis();
+
+    bool simdReduce =
+        triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_SIMD_REDUCE");
+
+    if (simdReduce) {
+      auto ret = targetInfo.warpBatchReduce(rewriter, op.getLoc(), accs, op,
+                                            sizeIntraWarps,
+                                            threadOffsetOnReductionAxis);
+      if (ret)
+        return;
+    }
+
     for (auto it : accs) {
       const SmallVector<unsigned> &key = it.first;
       SmallVector<Value> &acc = accs[key];
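Restated as a sketch, these are the conditions under which the SIMD path declines and lowering falls through to the generic shuffle-based warp reduce (a simplified mirror of the guards in warpBatchReduce in the next hunk; layout and arch checks are elided):

```python
def use_simd_reduce(num_lane_to_reduce, interleave, num_operands, num_results,
                    acc_size, is_addf_or_maxnumf):
    if num_lane_to_reduce == 1:  # nothing to reduce horizontally
        return False
    if interleave > 1:  # strided clusters are unsupported
        return False
    if (num_operands, num_results) != (1, 1):  # single-value reductions only
        return False
    if acc_size == 1 or acc_size & (acc_size - 1):  # need power-of-two batch
        return False
    return is_addf_or_maxnumf  # arith.addf / arith.maxnumf combine ops
```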
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp
index 512b6bd365..496112414b 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp
@@ -7,6 +7,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "TargetInfo.h"
+#include "intel/include/TritonIntelGPUToLLVM/XeAsmFormat.h"
+#include <algorithm>
+#include <numeric>
+
 #include "Dialect/TritonIntelGPU/IR/Utils.h"
 #include "SPIRVTargetInfo.h"
 #include "Utility.h"
@@ -15,6 +19,15 @@ using namespace mlir;
 
 namespace mlir::triton::intel {
 
+// Wraps the generated SIMD-reduce vISA snippet as a Xe inline-asm instruction.
+struct XeSIMDReduceInstr : public XeVISAInstr {
+  XeSIMDReduceInstr(XeBuilder *builder, std::string binOp, unsigned warpSize,
+                    unsigned numLaneToReduce, unsigned accSize, Type elemTy,
+                    XeArch arch)
+      : XeVISAInstr(builder, simdReduceAsm(binOp, warpSize, numLaneToReduce,
+                                           accSize, elemTy, arch)) {}
+};
+
 bool TargetInfo::supportMaximumMinimum() const { return false; }
 
 Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type,
                          Value cmp) const {
@@ -106,6 +119,157 @@ Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
   return rewriter.create<arith::IndexCastOp>(loc, i32_ty, blockId);
 }
 
+static SmallVector<ArrayRef<Value>> splitInBatches(ArrayRef<Value> srcValues,
+                                                   size_t batchSize) {
+  SmallVector<ArrayRef<Value>> batches;
+  for (; !srcValues.empty(); srcValues = srcValues.drop_front(batchSize))
+    batches.push_back(srcValues.take_front(batchSize));
+  return batches;
+}
+
+bool TargetInfo::warpBatchReduce(
+    RewriterBase &rewriter, Location loc,
+    std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
+    triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const {
+  // No horizontal reduce required.
+  if (numLaneToReduce == 1)
+    return false;
+  // Horizontal reduce with an interleave stride is not supported.
+  if (interleave > 1)
+    return false;
+  // Check that this is a simple reduction: a single combine operation applied
+  // directly to the two block arguments (e.g. add or max).
+  if (op.getNumOperands() != 1 || op.getNumResults() != 1)
+    return false;
+  Region &combineOp = op.getCombineOp();
+  if (combineOp.getBlocks().size() > 1)
+    return false;
+  Block &block = *combineOp.begin();
+  Operation *yield = block.getTerminator();
+  Operation *reduceOp = yield->getOperand(0).getDefiningOp();
+  if (!reduceOp || reduceOp->getNumOperands() != 2 ||
+      reduceOp->getNumResults() != 1)
+    return false;
+  if (reduceOp->getOperand(0) != block.getArgument(0) ||
+      reduceOp->getOperand(1) != block.getArgument(1))
+    return false;
+
+  auto mod = op->getParentOfType<ModuleOp>();
+  unsigned warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+
+  // TODO: support clustered reduce.
+  // if (numLaneToReduce != warpSize)
+  //   return false;
+
+  if (!isSupportedWarpReduceOp(reduceOp, numLaneToReduce, warpSize))
+    return false;
+
+  unsigned minSGSize =
+      mod->getAttrOfType<IntegerAttr>(
+             gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())
+          .getInt();
+
+  if (isa<arith::AddFOp, arith::MaxNumFOp>(reduceOp)) {
+    // The inline vISA leaves its results packed in the low lanes, so the
+    // number of batched SIMD-reduce results has to align with the warp size.
+    unsigned accSize = acc.size();
+    if (accSize == 1)
+      return false;
+    // The batch size must be a power of two.
+    if ((accSize - 1) & accSize)
+      return false;
+    unsigned numResultPerRow = warpSize / numLaneToReduce;
+    accSize = std::min(accSize, warpSize / numResultPerRow);
+    Type elemType = acc.begin()->second[0].getType();
+    VectorType reduceTy = vec_ty(elemType, accSize);
+
+    // Flatten the per-key accumulators into one batch-ordered list.
+    SmallVector<Value> inputAccs;
+    for (auto it : acc) {
+      const SmallVector<unsigned> &key = it.first;
+      SmallVector<Value> &val = acc[key];
+      assert(val.size() == 1 && "acc size has to be 1 for ungrouped input");
+      inputAccs.push_back(val[0]);
+    }
+    SmallVector<SmallVector<Value>> resultAccs(inputAccs.size() / accSize);
+
+    std::string batchedHorizontalReduce;
+    // TODO: support all possible reduction modes.
+    TypeSwitch<Operation *>(reduceOp)
+        .Case<arith::AddFOp>([&](auto) { batchedHorizontalReduce = "add"; })
+        .Case<arith::MaxNumFOp>([&](auto) { batchedHorizontalReduce = "max"; })
+        .Default(
+            [&](auto) { llvm_unreachable("Unhandled batched reduce kind"); });
+
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+    auto laneId = getLaneId(rewriter, loc);
+    auto reducePartial = b.udiv(laneId, b.i32_val(numLaneToReduce));
+    llvm::transform(
+        splitInBatches(inputAccs, accSize), std::begin(resultAccs),
+        [&](ArrayRef<Value> inputs) {
+          // Pack the batch into a single vector operand for the inline asm.
+          auto inputRange = llvm::enumerate(inputs);
+          Value batchedReduceVal = std::accumulate(
+              std::begin(inputRange), std::end(inputRange),
+              rewriter.create<LLVM::UndefOp>(loc, reduceTy).getRes(),
+              [reduceTy, loc, &rewriter](Value acc, auto entry) -> Value {
+                auto [index, src] = entry;
+                auto b = TritonLLVMOpBuilder(loc, rewriter);
+                return b.insert_element(reduceTy, acc, src, b.i32_val(index));
+              });
+          XeBuilder xeBuilder;
+          XeSIMDReduceInstr &bReduceOp = *xeBuilder.create<XeSIMDReduceInstr>(
+              batchedHorizontalReduce, warpSize, numLaneToReduce, accSize,
+              elemType, minSGSize == 8 ? Xe : Xe2);
+          // The vISA inline asm doesn't support a uniform result type
+          // ("=rw.u"), so the result stays a per-lane value and the reduced
+          // values have to be shuffled out of the low lanes below.
+          XeBuilder::Operand *res = xeBuilder.newOperand("=rw");
+          XeBuilder::Operand *in = xeBuilder.newOperand(batchedReduceVal, "rw");
+          bReduceOp({res, in}, /*onlyAttachMLIRArgs=*/true);
+          Value result = xeBuilder.launch(rewriter, loc, elemType, false);
+          SmallVector<Value> ret(accSize);
+          for (unsigned j = 0; j < accSize; ++j) {
+            if (numResultPerRow > 1) {
+              ret[j] = LLVM::intel::shuffleIdx(
+                  loc, rewriter, result,
+                  b.add(b.i32_val(j * numResultPerRow), reducePartial));
+            } else {
+              ret[j] = LLVM::intel::shuffleIdx(loc, rewriter, result, j);
+            }
+          }
+          return ret;
+        });
+
+    // Copy the shuffled per-batch results back into the flat list, then into
+    // the accumulator map.
+    unsigned grouped_iter = 0;
+    for (unsigned i = 0; i < resultAccs.size(); ++i) {
+      SmallVector<Value> &ret = resultAccs[i];
+      for (unsigned j = 0; j < accSize; ++j)
+        inputAccs[grouped_iter++] = ret[j];
+    }
+    grouped_iter = 0;
+    for (auto it : acc) {
+      const SmallVector<unsigned> &key = it.first;
+      SmallVector<Value> &val = acc[key];
+      val[0] = inputAccs[grouped_iter++];
+    }
+
+    return true;
+  }
+
+  return false;
+}
+
 bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
                             SmallVector<Value> &acc, triton::ReduceOp op,
                             unsigned numLaneToReduce,
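A small illustrative model (not part of the patch) of the result redistribution done with shuffleIdx above: the inline asm leaves the accSize * numResultPerRow partial results packed in the low lanes, and each lane fetches its j-th accumulator from the source lane computed here, mirroring j * numResultPerRow + reducePartial.

```python
def shuffle_src_lane(j, lane, num_lane_to_reduce, warp_size):
    num_result_per_row = warp_size // num_lane_to_reduce
    reduce_partial = lane // num_lane_to_reduce  # which cluster this lane is in
    if num_result_per_row > 1:
        return j * num_result_per_row + reduce_partial
    return j  # full-warp reduce: plain broadcast of result j


# Full-warp reduce in a 16-lane warp: every lane reads result j from lane j.
assert shuffle_src_lane(3, 7, num_lane_to_reduce=16, warp_size=16) == 3
# Clustered reduce over 8 lanes: lane 12 sits in cluster 1, so it reads 3*2+1.
assert shuffle_src_lane(3, 12, num_lane_to_reduce=8, warp_size=16) == 7
```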
"=rw.u" + // auto res = vISABuilder.newOperand("=rw.u"); + XeBuilder::Operand *res = xeBuilder.newOperand("=rw"); + XeBuilder::Operand *in = xeBuilder.newOperand(batchedReduceVal, "rw"); + bReduceOp({res, in}, /*onlyAttachMLIRArgs=*/true); + // Type resultTy = reduceTy.getElementType(); + Value result = xeBuilder.launch(rewriter, loc, elemType, false); + SmallVector ret(accSize); + for (unsigned j = 0; j < accSize; ++j) { + if (numResultPerRow > 1) { + ret[j] = LLVM::intel::shuffleIdx( + loc, rewriter, result, + b.add(b.i32_val(j * numResultPerRow), reducePartial)); + } else { + ret[j] = LLVM::intel::shuffleIdx(loc, rewriter, result, j); + } + } + return ret; + }); + + unsigned grouped_iter = 0; + for (unsigned i = 0; i < resultAccs.size(); ++i) { + // The output of the inline vISA has to be the non-uniform value. + // Have to shuffle the result to get the reduce value. + SmallVector ret = resultAccs[i]; + for (unsigned j = 0; j < accSize; ++j) { + inputAccs[grouped_iter++] = ret[j]; + } + } + grouped_iter = 0; + for (auto it : acc) { + const SmallVector &key = it.first; + SmallVector &val = acc[key]; + val[0] = inputAccs[grouped_iter++]; + } +#if 0 + auto res = vISABuilder.newOperand("=rw.u"); + auto in = vISABuilder.newOperand(batchedReduceVal, "rw"); + bReduceOp({res, in}, /*onlyAttachMLIRArgs=*/true); + Value ret = vISABuilder.launch(rewriter, loc, reduceTy, false); + Type resultTy = reduceTy.getElementType(); + for (unsigned i = 0; i < grouped_accs.size(); ++i) { + grouped_accs[i] = b.extract_element(resultTy, ret, b.i32_val(i)); + } +#endif + + return true; + } + + return false; +} + bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h index f63c55c7f2..f6875d729b 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h @@ -50,6 +50,11 @@ class TargetInfo : public mlir::triton::TargetInfoBase { Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, ProgramIDDim axis) const override; + bool warpBatchReduce(RewriterBase &rewriter, Location loc, + std::map, SmallVector> &acc, + triton::ReduceOp op, unsigned numLaneToReduce, + unsigned interleave) const override; + bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const override; @@ -92,5 +97,6 @@ class TargetInfo : public mlir::triton::TargetInfoBase { }; std::unique_ptr createTargetInfo(ModuleOp mod); + } // namespace mlir::triton::intel #endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOINTEL_H diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/XeAsmFormat.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/XeAsmFormat.cpp index 7ffa9b9c31..57e36979f4 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/XeAsmFormat.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/XeAsmFormat.cpp @@ -1,8 +1,9 @@ #include "intel/include/TritonIntelGPUToLLVM/XeAsmFormat.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Transforms/DialectConversion.h" #include "triton/Conversion/TritonGPUToLLVM/AsmFormat.h" #include "llvm/Support/raw_ostream.h" +#include +#include #include namespace mlir { @@ -231,5 +232,93 @@ XeInstr &XeInstr::b(int width) { return *this; } +std::optional XeVISAInstr::getTypeName(Type scalarTy) { + std::string typeSyntax; + 
diff --git a/third_party/intel/unittest/Conversion/TritonIntelGPUToLLVM/XeAsmFormatTest.cpp b/third_party/intel/unittest/Conversion/TritonIntelGPUToLLVM/XeAsmFormatTest.cpp
index 315ee9a690..d806200eba 100644
--- a/third_party/intel/unittest/Conversion/TritonIntelGPUToLLVM/XeAsmFormatTest.cpp
+++ b/third_party/intel/unittest/Conversion/TritonIntelGPUToLLVM/XeAsmFormatTest.cpp
@@ -78,6 +78,22 @@ TEST_F(XeAsmFormatTest, MultiLineXe) {
   EXPECT_EQ(values[1], v[2]); // $1 -> v[2]
 }
 
+TEST_F(XeAsmFormatTest, XeSIMDReduce) {
+  // Smoke-test the vISA generator for the supported op/type/arch variants.
+  auto assemble = mlir::triton::simdReduceAsm(
+      "add", 16, 16, 16, Float16Type::get(&ctx), XeArch::Xe);
+  EXPECT_FALSE(assemble.empty());
+  assemble = mlir::triton::simdReduceAsm("max", 16, 16, 16,
+                                         Float32Type::get(&ctx), XeArch::Xe);
+  EXPECT_FALSE(assemble.empty());
+  assemble = mlir::triton::simdReduceAsm("add", 32, 32, 32,
+                                         Float32Type::get(&ctx), XeArch::Xe2);
+  EXPECT_FALSE(assemble.empty());
+  assemble = mlir::triton::simdReduceAsm("max", 32, 32, 32,
+                                         Float32Type::get(&ctx), XeArch::Xe2);
+  EXPECT_FALSE(assemble.empty());
+  // TODO: integer element types (e.g. i32) are not supported yet.
+}
+
 } // namespace triton
 } // namespace mlir
"$1" : "temp_result", + srcBOffM, srcBOffN, rowStride, colNum) + .str(); + simdAsm += simdOps; + } + } + simdAsm += "}"; + return simdAsm; +} + } // namespace triton } // namespace mlir diff --git a/third_party/intel/unittest/Conversion/TritonIntelGPUToLLVM/XeAsmFormatTest.cpp b/third_party/intel/unittest/Conversion/TritonIntelGPUToLLVM/XeAsmFormatTest.cpp index 315ee9a690..d806200eba 100644 --- a/third_party/intel/unittest/Conversion/TritonIntelGPUToLLVM/XeAsmFormatTest.cpp +++ b/third_party/intel/unittest/Conversion/TritonIntelGPUToLLVM/XeAsmFormatTest.cpp @@ -78,6 +78,22 @@ TEST_F(XeAsmFormatTest, MultiLineXe) { EXPECT_EQ(values[1], v[2]); // $1 -> v[2] } +TEST_F(XeAsmFormatTest, XeSIMDReduce) { + auto assemble = mlir::triton::simdReduceAsm( + "add", 16, 16, 16, Float16Type::get(&ctx), XeArch::Xe); + llvm::outs() << "johnlu:" << assemble; + assemble = mlir::triton::simdReduceAsm("max", 16, 16, 16, + Float32Type::get(&ctx), XeArch::Xe); + llvm::outs() << "johnlu:" << assemble; + assemble = mlir::triton::simdReduceAsm("add", 32, 32, 32, + Float32Type::get(&ctx), XeArch::Xe2); + llvm::outs() << "johnlu:" << assemble; + assemble = mlir::triton::simdReduceAsm("max", 32, 32, 32, + Float32Type::get(&ctx), XeArch::Xe2); + llvm::outs() << "johnlu:" << assemble; + // mlir::triton::simdReduceAsm("max", 16, 16, IntegerType::get(&ctx, 32)); +} + } // namespace triton } // namespace mlir diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index 4aed67d151..18088fbba7 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -49,6 +49,13 @@ class TargetInfo : public mlir::triton::TargetInfoBase { triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const override; + bool warpBatchReduce(RewriterBase &rewriter, Location loc, + std::map, SmallVector> &acc, + triton::ReduceOp op, unsigned numLaneToReduce, + unsigned interleave) const override { + return false; + }; + std::string getMulhiFuncName(Type resultElementTy) const override; void printf(RewriterBase &rewriter, Value formatStrStart,