Skip to content
Draft
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32",
"TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
"TRITON_INTEL_ENABLE_INSTR_SCHED",
"TRITON_INTEL_LOWER_DOT_TO_LOOP",
"TRITON_INTEL_FAST_MATH",
"TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT",
"TRITON_INTEL_REDUCE_TRANSPOSE",
Expand Down
2 changes: 2 additions & 0 deletions third_party/intel/lib/Analysis/DPAS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ DPASAnalysis::DPASAnalysis(Operation *root) {

DPASAnalysis::Result
DPASAnalysis::canUseDPAS(FunctionOpInterface funcOp) const {
return Result::False;
if (funcToDotMap.empty() || dotToDPASEngineMap.empty())
return Result::False;

Expand Down Expand Up @@ -78,6 +79,7 @@ DPASAnalysis::canUseDPAS(FunctionOpInterface funcOp) const {
: Result::False;
}

return Result::False;
return (threadsPerWarp == minSGSize) ? Result::True : Result::False;
}

Expand Down
2 changes: 2 additions & 0 deletions third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ add_triton_library(TritonIntelGPUToLLVM
ControlFlowOpToLLVM.cpp
ConvertLayoutOpToLLVM.cpp
DotOpToLLVM/DPAS.cpp
DotOpToLLVM/FMA.cpp
DotOpToLLVM/FMADotUtility.cpp
DotOpToLLVM.cpp
ElementwiseOpToLLVM.cpp
Fp4ToFpOpToLLVM.cpp
Expand Down
21 changes: 12 additions & 9 deletions third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
#include "PatternTritonGPUOpToLLVM.h"

using namespace mlir;
using namespace mlir::triton;

using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::intel::DpasEncodingAttr;

namespace fma_details {
LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor,
namespace mlir::triton::gpu::intel {
LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter);

LogicalResult convertDPAS(DotOp op, DotOp::Adaptor adaptor,
TritonIntelGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter);
} // namespace fma_details
} // namespace mlir::triton::gpu::intel

namespace {
struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::DotOp>::ConvertTritonGPUOpToLLVMPattern;
using ConvertTritonGPUOpToLLVMPattern::ConvertTritonGPUOpToLLVMPattern;

LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
Expand All @@ -33,13 +35,14 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {

if (!isOuter && isa<DpasEncodingAttr>(
cast<RankedTensorType>(D.getType()).getEncoding())) {
return fma_details::convertDPAS(op, adaptor, getTypeConverter(),
rewriter);
return triton::gpu::intel::convertDPAS(op, adaptor, getTypeConverter(),
rewriter);
}

if (isa<BlockedEncodingAttr>(
cast<RankedTensorType>(D.getType()).getEncoding()))
return convertFMADot(op, adaptor, getTypeConverter(), rewriter);
return triton::gpu::intel::convertFMADot(op, adaptor, getTypeConverter(),
rewriter);

llvm::report_fatal_error(
"Unsupported DotOp found when converting TritonGPU to LLVM.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,8 @@ class DotOpDPASConversionHelper {

} // namespace

namespace fma_details {
namespace mlir::triton::gpu::intel {

LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonIntelGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) {
Expand Down Expand Up @@ -441,4 +442,5 @@ LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor,

return helper.convertDot(op, adaptor);
}
} // namespace fma_details

} // namespace mlir::triton::gpu::intel
66 changes: 66 additions & 0 deletions third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::triton;
using namespace ::mlir::triton::gpu;

namespace {
/// Emits the scalar multiply-accumulate chain for the generic FMA dot
/// lowering: folds element-wise products of two operand vectors into a
/// single accumulator value.
class GenericFMAVectorMultiplier : public FMAVectorMultiplier {
  OpBuilder &builder;
  Location loc;

public:
  GenericFMAVectorMultiplier(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc) {}

  /// Returns c + sum_k(a[k] * b[k]).
  /// \p a and \p b must have equal length, and every element must have the
  /// same type as \p c.
  Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
                        Value c) override {
    // Folding size() into the assert avoids an unused-variable warning in
    // NDEBUG builds (the old local `K` was only read by the assert).
    assert(a.size() == b.size() && "operand vectors must have equal length");
    Value accum = c;
    Type tgtTy = accum.getType();
    // Iterate one zipped range. The previous form compared iterators taken
    // from two distinct temporary zip objects
    // (`zip(a, b).begin() != zip(a, b).end()`), which is not a valid
    // iterator pair; a range-for keeps a single range alive for the loop.
    for (auto [aElem, bElem] : llvm::zip(a, b)) {
      assert(aElem.getType() == tgtTy);
      assert(bElem.getType() == tgtTy);

      // 'llvm.intr.fmuladd' only accepts floating point operands
      // ("op operand #0 must be floating point LLVM type ... but got
      // 'i32'"), so integer element types get an explicit mul + add.
      llvm::TypeSwitch<Type>(tgtTy)
          .Case<FloatType>([&](auto) {
            accum = builder.create<LLVM::FMulAddOp>(loc, aElem, bElem, accum);
          })
          .Case<IntegerType>([&](auto) {
            accum = builder.create<LLVM::AddOp>(
                loc, builder.create<LLVM::MulOp>(loc, aElem, bElem), accum);
          });
    }
    return accum;
  }
};

} // namespace

namespace mlir::triton::gpu::intel {

/// Shared FMA dot lowering driver (defined in FMADotUtility.cpp); walks the
/// dot operands and uses \p multiplier to emit each scalar multiply-add.
LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,
                                      const LLVMTypeConverter *typeConverter,
                                      ConversionPatternRewriter &rewriter,
                                      FMAVectorMultiplier &multiplier);

/// Lowers a blocked-layout tt.dot op to a chain of scalar FMA (or mul+add)
/// operations via the parametric FMA lowering helper.
LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor,
                            const LLVMTypeConverter *typeConverter,
                            ConversionPatternRewriter &rewriter) {
  // Dropped the old unused `ctx` local (rewriter.getContext() was never
  // read); the redundant `intel::` qualifier is also unnecessary inside
  // this namespace.
  GenericFMAVectorMultiplier multiplier(rewriter, op.getLoc());
  return parametricConvertFMADot(op, adaptor, typeConverter, rewriter,
                                 multiplier);
}

} // namespace mlir::triton::gpu::intel
Loading
Loading