diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp
index fa299fb36b..942f7f04aa 100644
--- a/include/triton/Tools/Sys/GetEnv.hpp
+++ b/include/triton/Tools/Sys/GetEnv.hpp
@@ -51,6 +51,7 @@ inline const std::set<std::string_view> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32",
     "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
     "TRITON_INTEL_ENABLE_INSTR_SCHED",
     "TRITON_INTEL_FAST_MATH",
+    "TRITON_INTEL_LOWER_DOT_TO_LOOP",
     "TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT",
     "TRITON_INTEL_REDUCE_TRANSPOSE",
diff --git a/third_party/intel/lib/Analysis/DPAS.cpp b/third_party/intel/lib/Analysis/DPAS.cpp
index a20cc53777..6bf5a53331 100644
--- a/third_party/intel/lib/Analysis/DPAS.cpp
+++ b/third_party/intel/lib/Analysis/DPAS.cpp
@@ -41,6 +41,9 @@ DPASAnalysis::DPASAnalysis(Operation *root) {
 }

 DPASAnalysis::Result
 DPASAnalysis::canUseDPAS(FunctionOpInterface funcOp) const {
+  // FIXME: temporarily force the FMA path by disabling DPAS selection while
+  // the loop-based lowering is under development.
+  return Result::False;
   if (funcToDotMap.empty() || dotToDPASEngineMap.empty())
     return Result::False;
@@ -78,6 +81,7 @@ DPASAnalysis::canUseDPAS(FunctionOpInterface funcOp) const {
                : Result::False;
   }

+  return Result::False;
   return (threadsPerWarp == minSGSize) ? Result::True : Result::False;
 }
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt
index c95ad95f43..c8f27bc8f3 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt
@@ -5,6 +5,8 @@ add_triton_library(TritonIntelGPUToLLVM
   ControlFlowOpToLLVM.cpp
   ConvertLayoutOpToLLVM.cpp
   DotOpToLLVM/DPAS.cpp
+  DotOpToLLVM/FMA.cpp
+  DotOpToLLVM/FMADotUtility.cpp
   DotOpToLLVM.cpp
   ElementwiseOpToLLVM.cpp
   Fp4ToFpOpToLLVM.cpp
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp
index 6283e68a31..d115aba134 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp
@@ -1,21 +1,23 @@
 #include "PatternTritonGPUOpToLLVM.h"

 using namespace mlir;
-using namespace mlir::triton;

 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::intel::DpasEncodingAttr;

-namespace fma_details {
-LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor,
+namespace mlir::triton::gpu::intel {
+LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor,
+                            const LLVMTypeConverter *typeConverter,
+                            ConversionPatternRewriter &rewriter);
+
+LogicalResult convertDPAS(DotOp op, DotOp::Adaptor adaptor,
                           TritonIntelGPUToLLVMTypeConverter *typeConverter,
                           ConversionPatternRewriter &rewriter);
-} // namespace fma_details
+} // namespace mlir::triton::gpu::intel

 namespace {
 struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
-  using ConvertTritonGPUOpToLLVMPattern<
-      triton::DotOp>::ConvertTritonGPUOpToLLVMPattern;
+  using ConvertTritonGPUOpToLLVMPattern::ConvertTritonGPUOpToLLVMPattern;

   LogicalResult
   matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
@@ -33,13 +35,14 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
     if (!isOuter && isa<DpasEncodingAttr>(
                         cast<RankedTensorType>(D.getType()).getEncoding())) {
-      return fma_details::convertDPAS(op, adaptor, getTypeConverter(),
-                                      rewriter);
+      return triton::gpu::intel::convertDPAS(op, adaptor, getTypeConverter(),
+                                             rewriter);
     }

     if (isa<triton::gpu::BlockedEncodingAttr>(
             cast<RankedTensorType>(D.getType()).getEncoding()))
-      return convertFMADot(op, adaptor, getTypeConverter(), rewriter);
+      return triton::gpu::intel::convertFMADot(op, adaptor, getTypeConverter(),
+                                               rewriter);

     llvm::report_fatal_error(
         "Unsupported DotOp found when converting TritonGPU to LLVM.");
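Note: the conversions below build on the `FMAVectorMultiplier` interface declared in `triton/Conversion/TritonGPUToLLVM/FMADotUtility.h`, which this diff does not touch. A sketch of its expected shape, inferred from the usage in this patch (details may differ from the actual header):

```cpp
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"

// Inferred from usage in FMA.cpp / FMADotUtility.cpp below; illustration only.
class FMAVectorMultiplier {
public:
  virtual ~FMAVectorMultiplier() = default;

  /// Emits IR computing c + sum_k(a[k] * b[k]) and returns the accumulator.
  virtual mlir::Value multiplyVectors(llvm::ArrayRef<mlir::Value> a,
                                      llvm::ArrayRef<mlir::Value> b,
                                      mlir::Value c) = 0;
};
```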
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
index 71f1a20b9f..188a1fde24 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
@@ -406,7 +406,8 @@ class DotOpDPASConversionHelper {

 } // namespace

-namespace fma_details {
+namespace mlir::triton::gpu::intel {
+
 LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor,
                           TritonIntelGPUToLLVMTypeConverter *typeConverter,
                           ConversionPatternRewriter &rewriter) {
@@ -441,4 +442,5 @@ LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor,

   return helper.convertDot(op, adaptor);
 }
-} // namespace fma_details
+
+} // namespace mlir::triton::gpu::intel
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp
new file mode 100644
index 0000000000..2a59b5d9a1
--- /dev/null
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp
@@ -0,0 +1,61 @@
+#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::triton;
+using namespace ::mlir::triton::gpu;
+
+namespace {
+class GenericFMAVectorMultiplier : public FMAVectorMultiplier {
+  OpBuilder &builder;
+  Location loc;
+
+public:
+  GenericFMAVectorMultiplier(OpBuilder &builder, Location loc)
+      : builder(builder), loc(loc) {}
+
+  Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
+                        Value c) override {
+    auto K = a.size();
+    assert(b.size() == K);
+    Value accum = c;
+    Type tgtTy = accum.getType();
+    for (auto [aElem, bElem] : llvm::zip(a, b)) {
+      assert(aElem.getType() == tgtTy);
+      assert(bElem.getType() == tgtTy);
+
+      // 'llvm.intr.fmuladd' requires floating point operands, so emit a
+      // mul/add pair for integer element types instead.
+      llvm::TypeSwitch<Type>(tgtTy)
+          .Case<FloatType>([&](auto) {
+            accum = builder.create<LLVM::FMulAddOp>(loc, aElem, bElem, accum);
+          })
+          .Case<IntegerType>([&](auto) {
+            accum = builder.create<LLVM::AddOp>(
+                loc, builder.create<LLVM::MulOp>(loc, aElem, bElem), accum);
+          });
+    }
+    return accum;
+  }
+};
+
+} // namespace
+
+namespace mlir::triton::gpu::intel {
+
+LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,
+                                      const LLVMTypeConverter *typeConverter,
+                                      ConversionPatternRewriter &rewriter,
+                                      FMAVectorMultiplier &multiplier);
+
+LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor,
+                            const LLVMTypeConverter *typeConverter,
+                            ConversionPatternRewriter &rewriter) {
+  auto loc = op.getLoc();
+  GenericFMAVectorMultiplier multiplier(rewriter, loc);
+  return parametricConvertFMADot(op, adaptor, typeConverter, rewriter,
+                                 multiplier);
+}
+
+} // namespace mlir::triton::gpu::intel
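`GenericFMAVectorMultiplier` reduces the two K-element operand vectors into the accumulator with a chain of fused multiply-adds. A host-side sketch of the value it computes (hypothetical `multiplyVectorsRef` helper, float case only):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Reference semantics of multiplyVectors for float elements: a sequential
// fused-multiply-add reduction, mirroring the emitted LLVM::FMulAddOp chain.
float multiplyVectorsRef(const std::vector<float> &a,
                         const std::vector<float> &b, float c) {
  assert(a.size() == b.size());
  for (size_t k = 0; k < a.size(); ++k)
    c = std::fma(a[k], b[k], c);
  return c;
}
```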
"mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Value.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +namespace { + +/// OperandValueKey structure represents compile time part +/// of spatial coordinates of a value in a tensor. +/// +/// Every Value spatial coordinates(i.e. [batch;nonK;k]) in tensor can be +/// defined as: +/// +/// batch = (bRepIdx * CTABSize + bIdx) + (laneBCoord + warpBCoord) +/// nonK = (nonKRepIdx * CTANKSize + nonKIdx) + (laneNonKCoord + warpNonKCoord) +/// k = kIdx +/// +/// Where: +/// CTABSize, CTANKSize: constants; +/// laneBCoord, warpBCoord, laneNonKCoord, warpNonKCoord: runtime components; +/// bRepIdx, nonKRepIdx, bIdx, nonKIdx, kIdx: compile time components. +struct OperandValueKey { + unsigned bRepIdx, nonKRepIdx; + unsigned bIdx, nonKIdx, kIdx; + + bool operator==(const OperandValueKey &other) const { + return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx && + bIdx == other.bIdx && nonKIdx == other.nonKIdx && + kIdx == other.kIdx); + } + + void print() const { + llvm::errs() << "[" << bRepIdx << "," << nonKRepIdx << "," << bIdx << "," + << nonKIdx << "," << kIdx << "]"; + } +}; + +} // namespace + +template <> struct std::hash { + std::size_t operator()(const OperandValueKey &k) const { + return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx, + k.kIdx); + } +}; + +namespace { + +using ValueTableFMA = std::unordered_map; + +ValueTableFMA getValueTableFromStructFMA( + Value val, ArrayRef perRepShape, ArrayRef repetitions, + unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter, + Location loc, ArrayRef inRepOrder, ArrayRef repOrder) { + ValueTableFMA res; + auto elems = unpackLLElements(loc, val, rewriter); + assert(perRepShape.size() == 3); + auto numElemsRep = product(perRepShape); + assert(elems.size() == numElemsRep * product(repetitions)); + assert(kDim == 1 || kDim == 2); + assert(nonKDim == 1 || nonKDim == 2); + const unsigned bDim = 0; + + for (unsigned idx = 0; idx < elems.size(); ++idx) { + auto inRepLinearIdx = idx % numElemsRep; + auto repLinearIdx = idx / numElemsRep; + auto inRepSpatialIdx = + mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder); + auto repSpatialIdx = + mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder); + OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim], + inRepSpatialIdx[0], inRepSpatialIdx[nonKDim], + inRepSpatialIdx[kDim]}; + + llvm::errs() << "key: "; + key.print(); + llvm::errs() << "idx = " << idx << "\n"; + res[key] = elems[idx]; + } + return res; +} + +struct LoopInfo { + Block *header; + Block *body; + Block *end; +}; + +/// Creates an empty loop structure with a header, body, and end block. The +/// loop is initialized with an induction variable and an initial argument. +/// - Parameters: +/// - `iv`: The induction variable (already initialized to the lower bound). +/// - `ub`: The upper bound for the loop. +/// - `step`: The step for the induction variable. +/// - `initArg`: The initial argument passed to the loop. +/// - Returns a `LoopInfo` structure containing the header, body, and end +// blocks. 
+
+/// Creates an empty loop structure with a header, a body, and an end block.
+/// The loop is initialized with an induction variable and an initial carried
+/// argument.
+/// - Parameters:
+///   - `iv`: The induction variable (already initialized to the lower bound).
+///   - `ub`: The upper bound for the loop.
+///   - `step`: The step for the induction variable.
+///   - `initArg`: The initial argument passed to the loop.
+/// - Returns a `LoopInfo` structure containing the header, body, and end
+///   blocks.
+LoopInfo createEmptyLoop(Value iv, Value ub, Value step, Value initArg,
+                         ConversionPatternRewriter &rewriter, Location loc) {
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
+  MLIRContext *ctx = rewriter.getContext();
+
+  // Create the loop blocks.
+  Block *insertionBlock = rewriter.getInsertionBlock();
+  Block *headerBlock =
+      rewriter.splitBlock(insertionBlock, rewriter.getInsertionPoint());
+  Block *bodyBlock = rewriter.splitBlock(headerBlock, headerBlock->begin());
+  Block *endBlock = rewriter.splitBlock(bodyBlock, bodyBlock->begin());
+
+  // Add arguments to the blocks.
+  Type ivTy = iv.getType();
+  Type initArgTy = initArg.getType();
+  headerBlock->addArguments({ivTy, initArgTy}, {loc, loc});
+  bodyBlock->addArguments({ivTy, initArgTy}, {loc, loc});
+  endBlock->addArgument(initArgTy, loc);
+
+  // Connect the insertion block to the header block.
+  rewriter.setInsertionPointToEnd(insertionBlock);
+  rewriter.create<cf::BranchOp>(loc, headerBlock, ValueRange{iv, initArg});
+
+  // Build the header block.
+  rewriter.setInsertionPointToStart(headerBlock);
+  Value cond = b.icmp_slt(headerBlock->getArgument(0), ub);
+  rewriter.create<cf::CondBranchOp>(loc, cond, bodyBlock,
+                                    headerBlock->getArguments(), endBlock,
+                                    ValueRange{headerBlock->getArgument(1)});
+
+  // Build the body block.
+  rewriter.setInsertionPointToStart(bodyBlock);
+  Value nextIV = b.add(bodyBlock->getArgument(0), step);
+  auto latch = rewriter.create<cf::BranchOp>(
+      loc, headerBlock, ValueRange{nextIV, bodyBlock->getArgument(1)});
+
+  // Attach `llvm.loop.unroll.disable` metadata to the latch branch so the
+  // loop is not unrolled.
+  auto noUnrollAttr = StringAttr::get(ctx, "llvm.loop.unroll.disable");
+  auto arrayAttr = rewriter.getArrayAttr({noUnrollAttr});
+  latch->setAttr("llvm.loop", arrayAttr);
+
+  // Leave the insertion point at the start of the end block.
+  rewriter.setInsertionPointToStart(endBlock);
+
+  return {headerBlock, bodyBlock, endBlock};
+}
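A minimal usage sketch for `createEmptyLoop` (hypothetical `emitSumZeroToTen`, assuming a rewriter with a valid insertion point): accumulate the induction variable into the carried argument, wiring the new value into the back edge the same way `genFMALoop` does below.

```cpp
// Hypothetical illustration (not part of the patch): sum 0..9 into the
// loop-carried argument using the helpers above.
static Value emitSumZeroToTen(ConversionPatternRewriter &rewriter,
                              Location loc) {
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  LoopInfo loop = createEmptyLoop(createIV(b.i32_val(0), rewriter, loc),
                                  /*ub=*/b.i32_val(10), /*step=*/b.i32_val(1),
                                  /*initArg=*/b.i32_val(0), rewriter, loc);
  // Compute the next carried value right before the latch branch...
  rewriter.setInsertionPoint(loop.body->getTerminator());
  Value next = b.add(loop.body->getArgument(0), loop.body->getArgument(1));
  // ...and feed it into the back edge.
  auto latch = cast<cf::BranchOp>(loop.body->getTerminator());
  latch->setOperand(latch->getNumOperands() - 1, next);
  return loop.end->getArgument(0); // the final sum, available after the loop
}
```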
+
+/// Creates an induction variable: stores `init` into a fresh stack slot and
+/// immediately loads it back.
+/// - Parameters:
+///   - `init`: The initial value to assign to the variable.
+/// - Returns: The loaded value of the initialized variable.
+Value createIV(Value init, ConversionPatternRewriter &rewriter, Location loc) {
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
+  auto ptr = rewriter.create<LLVM::AllocaOp>(
+      loc, ptr_ty(rewriter.getContext()), init.getType(), b.i32_val(1));
+  rewriter.create<LLVM::StoreOp>(loc, init, ptr);
+  return rewriter.create<LLVM::LoadOp>(loc, init.getType(), ptr);
+}
+
+} // namespace
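The store/load round-trip in `createIV`, together with the `llvm.loop.unroll.disable` metadata above, appears intended to keep the loop bounds from being constant-folded so the loops are actually materialized. This is an inference about intent; roughly analogous to:

```cpp
// Host-side analogy (assumption about intent, not part of the patch): a
// value reloaded from memory is not a compile-time constant.
unsigned opaqueInit(unsigned init) {
  volatile unsigned slot = init; // store to a stack slot
  return slot;                   // load it back; foils constant folding
}
```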
+
+namespace mlir::triton::gpu::intel {
+
+LogicalResult genFMALoop(DotOp, ValueTableFMA &, ValueTableFMA &,
+                         ArrayRef<Value>, ArrayRef<unsigned>,
+                         ArrayRef<unsigned>, ArrayRef<unsigned>,
+                         ArrayRef<unsigned>, const unsigned, const unsigned,
+                         Type, ConversionPatternRewriter &,
+                         FMAVectorMultiplier &);
+
+LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,
+                                      const LLVMTypeConverter *typeConverter,
+                                      ConversionPatternRewriter &rewriter,
+                                      FMAVectorMultiplier &multiplier) {
+  auto loc = op.getLoc();
+
+  auto A = op.getA();
+  auto D = op.getResult();
+
+  auto aTensorTy = cast<RankedTensorType>(A.getType());
+  auto dTensorTy = cast<RankedTensorType>(D.getType());
+
+  SmallVector<int64_t> aShapePerCTA =
+      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy)));
+  auto dShapePerCTA =
+      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy)));
+
+  BlockedEncodingAttr dLayout =
+      cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
+  // TODO: process the A and B operands separately.
+  auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());
+  auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());
+  auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);
+
+  Value llA = adaptor.getA();
+  Value llB = adaptor.getB();
+
+  auto sizePerThread = getContigPerThread(dTensorTy);
+  auto numElemsPerThread = product(sizePerThread);
+  SmallVector<unsigned> shapePerCTATile;
+  for (auto [reg, thread, warp] :
+       llvm::zip(sizePerThread, dLayout.getThreadsPerWarp(),
+                 dLayout.getWarpsPerCTA())) {
+    shapePerCTATile.push_back(reg * thread * warp);
+  }
+  shapePerCTATile = expandMatrixShapeWithBatch(ArrayRef(shapePerCTATile));
+  sizePerThread = expandMatrixShapeWithBatch(ArrayRef(sizePerThread));
+
+  unsigned K = aShapePerCTA[2];
+
+  unsigned repetitions[3];
+  for (int i = 0; i < 3; ++i) {
+    repetitions[i] =
+        ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));
+  }
+
+  auto has = getValueTableFromStructFMA(
+      llA, {sizePerThread[0], sizePerThread[1], K},
+      {repetitions[0], repetitions[1], 1},
+      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);
+  auto hbs = getValueTableFromStructFMA(
+      llB, {sizePerThread[0], K, sizePerThread[2]},
+      {repetitions[0], 1, repetitions[2]},
+      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);
+
+  SmallVector<Value> acc = cc;
+
+  if (triton::tools::getBoolEnv("TRITON_INTEL_LOWER_DOT_TO_LOOP")) {
+    Type dType = typeConverter->convertType(dTensorTy);
+    return genFMALoop(op, has, hbs, acc, sizePerThread, repetitions,
+                      inRepOrder, repOrder, numElemsPerThread, K, dType,
+                      rewriter, multiplier);
+  }
+
+  for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
+    for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
+      for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
+        for (unsigned b = 0; b < sizePerThread[0]; ++b)
+          for (unsigned m = 0; m < sizePerThread[1]; ++m)
+            for (unsigned n = 0; n < sizePerThread[2]; ++n) {
+              SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
+              unsigned linearInRepIdx =
+                  LLVM::linearize(multiDimAccumIdx, sizePerThread, inRepOrder);
+              SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
+              unsigned linearRepIdx =
+                  LLVM::linearize(multiDimRepIdx, repetitions, repOrder);
+              unsigned linearAccumIdx =
+                  linearInRepIdx + linearRepIdx * numElemsPerThread;
+
+              SmallVector<Value> aOpVector;
+              SmallVector<Value> bOpVector;
+
+              for (unsigned k = 0; k < K; ++k) {
+                aOpVector.push_back(has.at({bRep, mRep, b, m, k}));
+                bOpVector.push_back(hbs.at({bRep, nRep, b, n, k}));
+              }
+
+              acc[linearAccumIdx] = multiplier.multiplyVectors(
+                  aOpVector, bOpVector, acc[linearAccumIdx]);
+            }
+
+  auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
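The `getFragment` lambda in `genFMALoop` below computes a flattened index from runtime `m`/`n` values. A scalar reference for the formula in its comment (hypothetical `fragmentIndexRef`; per the TODO it does not yet account for `bRep`, `nRep`, and `b`):

```cpp
// idx = mRep * M * N * K + m * N * K + n * K + k, i.e. row-major order over
// (mRep, m, n, k) with extents (-, M, N, K).
unsigned fragmentIndexRef(unsigned mRep, unsigned m, unsigned n, unsigned k,
                          unsigned M, unsigned N, unsigned K) {
  return ((mRep * M + m) * N + n) * K + k;
}
```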
+
+LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs,
+                         ArrayRef<Value> acc, ArrayRef<unsigned> sizePerThread,
+                         ArrayRef<unsigned> repetitions,
+                         ArrayRef<unsigned> inRepOrder,
+                         ArrayRef<unsigned> repOrder,
+                         const unsigned numElemsPerThread, const unsigned K,
+                         Type dType, ConversionPatternRewriter &rewriter,
+                         FMAVectorMultiplier &multiplier) {
+  MLIRContext *ctx = rewriter.getContext();
+  Location loc = op.getLoc();
+  auto builder = TritonLLVMOpBuilder(loc, rewriter);
+
+  // Copy the A, B, and C operand structs into vectors.
+  SmallVector<Value> aOpVector, bOpVector;
+  for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
+    for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
+      for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
+        for (unsigned b = 0; b < sizePerThread[0]; ++b)
+          for (unsigned m = 0; m < sizePerThread[1]; ++m)
+            for (unsigned n = 0; n < sizePerThread[2]; ++n)
+              for (unsigned k = 0; k < K; ++k) {
+                aOpVector.push_back(has.at({bRep, mRep, b, m, k}));
+                bOpVector.push_back(hbs.at({bRep, nRep, b, n, k}));
+              }
+
+  Value vecA = packLLVector(loc, aOpVector, rewriter);
+  Value vecB = packLLVector(loc, bOpVector, rewriter);
+  Value vecC = packLLVector(loc, acc, rewriter);
+
+  auto getFragment = [&](Value vec, unsigned bRep, unsigned mRep,
+                         unsigned nRep, unsigned b, Value outerIV,
+                         Value innerIV, Value outerUB, Value innerUB,
+                         unsigned K) {
+    SmallVector<Value> elems;
+    // TODO: compute idx by also using bRep, nRep, and b.
+    // idx = mRep * M * N * K + m * N * K + n * K + k
+    Value idx = builder.mul(
+        builder.mul(builder.mul(builder.i32_val(mRep), outerUB), innerUB),
+        builder.i32_val(K));
+    idx = builder.add(
+        idx, builder.mul(builder.mul(outerIV, innerUB), builder.i32_val(K)));
+    idx = builder.add(idx, builder.mul(innerIV, builder.i32_val(K)));
+
+    for (unsigned k = 0; k < K; ++k) {
+      idx = (k != 0) ? builder.add(idx, builder.i32_val(1)) : idx;
+      elems.push_back(builder.extract_element(vec, idx));
+    }
+    return elems;
+  };
+
+  auto linearize = [&](ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
+                       ArrayRef<unsigned> order) {
+    Value linear = builder.i32_val(0);
+    for (unsigned dim : llvm::reverse(order)) {
+      Value mul = builder.mul(linear, builder.i32_val(shape[dim]));
+      linear = builder.add(mul, multiDim[dim]);
+    }
+    return linear;
+  };
+
+  Value zero = builder.i32_val(0), one = builder.i32_val(1);
+  Type elemType = acc.front().getType();
+  Value initArg = vecC;
+
+  Value vecD;
+  for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) {
+    for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) {
+      for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) {
+        for (unsigned b = 0; b < sizePerThread[0]; ++b) {
+          // Generate the outer loop.
+          Value outerUB = builder.i32_val(sizePerThread[1]);
+          LoopInfo outerLoopInfo = createEmptyLoop(
+              createIV(zero, rewriter, loc), outerUB, one, initArg, rewriter,
+              loc);
+          Block *outerBody = outerLoopInfo.body;
+          Block *outerEnd = outerLoopInfo.end;
+          Value outerIV = outerBody->getArgument(0);
+          auto outerLatch = cast<cf::BranchOp>(outerBody->getTerminator());
+          vecD = outerEnd->getArgument(0);
+          auto afterOuterLoop = rewriter.saveInsertionPoint();
+          rewriter.setInsertionPointToStart(outerLoopInfo.body);
+
+          // Generate the inner loop.
+          Value innerUB = builder.i32_val(sizePerThread[2]);
+          Value outerCarried =
+              outerBody->getArgument(outerBody->getNumArguments() - 1);
+          LoopInfo innerLoopInfo =
+              createEmptyLoop(createIV(zero, rewriter, loc), innerUB, one,
+                              outerCarried, rewriter, loc);
+          Block *innerBody = innerLoopInfo.body;
+          Block *innerEnd = innerLoopInfo.end;
+          Value innerIV = innerBody->getArgument(0);
+          rewriter.setInsertionPointToStart(innerLoopInfo.body);
+
+          // Get the fragments for operands A and B.
+          SmallVector<Value> AElems = getFragment(
+              vecA, bRep, mRep, nRep, b, outerIV, innerIV, outerUB, innerUB, K);
+          SmallVector<Value> BElems = getFragment(
+              vecB, bRep, mRep, nRep, b, outerIV, innerIV, outerUB, innerUB, K);
+
+          // Compute the index into the accumulator.
+          SmallVector<Value> multiDimAccumIdx{builder.i32_val(b), outerIV,
+                                              innerIV};
+          Value linearInRepIdx =
+              linearize(multiDimAccumIdx, sizePerThread, inRepOrder);
+          SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
+          unsigned linearRepIdx =
+              LLVM::linearize(multiDimRepIdx, repetitions, repOrder);
+          Value linearAccumIdx = builder.add(
+              linearInRepIdx,
+              builder.i32_val(linearRepIdx * numElemsPerThread));
+          Value innerInitArg =
+              innerBody->getArgument(innerBody->getNumArguments() - 1);
+
+          // Extract the element from the accumulator.
+          Value CElem = builder.extract_element(innerInitArg, linearAccumIdx);
+
+          // Perform the FMAs.
+          CElem = multiplier.multiplyVectors(AElems, BElems, CElem);
+
+          // Store the result.
+          innerInitArg =
+              builder.insert_element(innerInitArg, CElem, linearAccumIdx);
+          rewriter.restoreInsertionPoint(afterOuterLoop);
+
+          // Pass the result to the next inner loop iteration.
+          auto innerLatch = cast<cf::BranchOp>(innerBody->getTerminator());
+          innerLatch->setOperand(innerLatch->getNumOperands() - 1,
+                                 innerInitArg);
+
+          // Pass the result of the inner loop to the next outer loop
+          // iteration.
+          Value innerEndArg = innerEnd->getArgument(0);
+          outerLatch->setOperand(outerLatch->getNumOperands() - 1,
+                                 innerEndArg);
+        }
+        initArg = vecD;
+      }
+    }
+  }
+
+  // Create a loop that copies the final result into a struct.
+  Value UB = builder.i32_val(acc.size());
+  auto structPtr =
+      rewriter.create<LLVM::AllocaOp>(loc, ptr_ty(ctx), elemType, UB);
+  LoopInfo loopInfo = createEmptyLoop(createIV(zero, rewriter, loc), UB, one,
+                                      vecD, rewriter, loc);
+  Block *body = loopInfo.body;
+  Value iv = body->getArgument(0);
+  auto afterLoop = rewriter.saveInsertionPoint();
+  rewriter.setInsertionPointToStart(body);
+  Value val = builder.extract_element(
+      body->getArgument(body->getNumArguments() - 1), iv);
+  Value ptr = builder.gep(ptr_ty(ctx), val.getType(), structPtr, iv);
+  rewriter.create<LLVM::StoreOp>(loc, val, ptr);
+
+  // Load the struct and replace the original op.
+  rewriter.restoreInsertionPoint(afterLoop);
+  auto loadVal = rewriter.create<LLVM::LoadOp>(loc, dType, structPtr);
+  rewriter.replaceOp(op, loadVal);
+
+  return success();
+}
+
+} // namespace mlir::triton::gpu::intel
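Taken together: for each `(bRep, mRep, nRep, b)` tuple, `genFMALoop` materializes the `m` and `n` loops in the emitted IR and keeps the `k` reduction unrolled as an FMA chain. A simplified host-side model of the emitted computation (hypothetical `fmaDotRef`, single repetition and batch, row-major layouts):

```cpp
#include <cmath>
#include <vector>

// Simplified model: C[m][n] += sum_k A[m][k] * B[k][n], with the m/n loops
// materialized and the k loop unrolled in the real IR.
void fmaDotRef(const std::vector<float> &A, const std::vector<float> &B,
               std::vector<float> &C, unsigned M, unsigned N, unsigned K) {
  for (unsigned m = 0; m < M; ++m)     // outer materialized loop
    for (unsigned n = 0; n < N; ++n) { // inner materialized loop
      float c = C[m * N + n];
      for (unsigned k = 0; k < K; ++k) // unrolled FMA chain in the IR
        c = std::fma(A[m * K + k], B[k * N + n], c);
      C[m * N + n] = c;
    }
}
```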