Skip to content

Commit f39300c

Browse files
committed
- Adds quant-dequant (f32->i8->f32) validation kernel generation and a runtime test.
- Adds the '-print-input' flag to print input arguments for visual inspection.
- Refactors and updates the corresponding APIs.
1 parent c49808a commit f39300c

File tree

8 files changed

+177
-65
lines changed

8 files changed

+177
-65
lines changed

include/TPP/Passes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,9 @@ def TppRunnerWrapper : Pass<"tpp-runner-wrapper", "ModuleOp">{
568568
Option<"printResult", "print", "bool",
569569
/*default=*/"false",
570570
"Print kernel results.">,
571+
Option<"printInput", "print-input", "bool",
572+
/*default=*/"false",
573+
"Print kernel inputs">,
571574
Option<"randomSplat", "random-splat", "bool",
572575
/*default=*/"false",
573576
"Replace splat dense tensors with random values.">,

include/TPP/Runner/MLIRBench.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ class MLIRBench {
109109
Value registerOnGpu(Value buf, MemRefType memRefTy);
110110

111111
public:
112+
/// Returns the kernel arguments (kernelArgs) by value, i.e. as a copy of the
/// stored vector.
113+
llvm::SmallVector<Value> getKernelArgs() { return kernelArgs; }
112114
/// Creates context, builder
113115
MLIRBench(Operation *op, const MLIRBenchConfig &config);
114116

lib/TPP/Runner/TppRunnerWrapper.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,19 @@ struct TppRunnerWrapper
111111
return;
112112
}
113113

114+
// Print the kernel's input arguments by iterating through kernelArgs.
// Rank-1 shaped arguments are skipped; only statically shaped ones are printed.
115+
if (printInput) {
116+
for (auto arg : bench.getKernelArgs()) {
117+
if (auto shapedType = dyn_cast<ShapedType>(arg.getType())) {
118+
if (shapedType.getRank() == 1)
119+
continue;
120+
if (shapedType.hasStaticShape())
121+
if (failed(bench.printShapedType(arg)))
122+
return;
123+
}
124+
}
125+
}
126+
114127
// Either run once or run benchmarks
115128
if (numBenchLoops > 1) {
116129
if (benchWarmup) {

test/Integration/mlir-gen-matmul.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,13 +195,14 @@
195195
// MXF32I8-QUANT: linalg.reduce {{.*}} dimensions = [0]
196196
// MXF32I8-QUANT: math.absf
197197
// MXF32I8-QUANT: arith.maximumf
198-
// MXF32I8-QUANT: linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel"]}
198+
// MXF32I8-QUANT: linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel"]}
199199
// MXF32I8-QUANT: llvm.intr.frexp
200200
// MXF32I8-QUANT: llvm.extractvalue
201201
// MXF32I8-QUANT: arith.constant 7
202202
// MXF32I8-QUANT: arith.subi
203203
// MXF32I8-QUANT: arith.subi
204204
// MXF32I8-QUANT: arith.sitofp
205+
// MXF32I8-QUANT: arith.sitofp
205206
// MXF32I8-QUANT: math.exp2
206207
// MXF32I8-QUANT: linalg.broadcast
207208
// MXF32I8-QUANT: linalg.mul

test/Integration/mlir-gen.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@
3535
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF
3636
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10,10 --tiles=2,2,2 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF
3737

38+
// Use two kernels: a plain matmul, and a matmul whose result is quantized and then dequantized.
39+
// RUN: mlir-gen --kernel=const --seed=0 --float-type=f32 --batch=3 --layers=3,3 --identity | tpp-run -e entry -entry-point-result=void -print --splat-to-random --init-type normal -seed 123 > %t.1
40+
// RUN: mlir-gen --identity --kernel=const --seed=0 --float-type=mx-f32-i8 --batch=3 --layers=3,3 --quant-type=testquant | tpp-run -e entry -entry-point-result=void -print --splat-to-random --init-type normal -seed 123 > %t.2
41+
42+
// Compare the matmul result with the quantize-dequantize kernel's result, using absolute and relative tolerances of 0.01.
43+
// RUN: fpcmp -a 0.01 -r 0.01 %t.1 %t.2
44+
45+
3846
// Implements C = A*B, with A=1, B=ID
3947
// IDENTITY_CONST:( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
4048

tools/mlir-gen/MLIRGen.cpp

Lines changed: 136 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "MLIRGen.h"
2121
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22+
// #include "mlir/IR/Value.h"
2223
#include "llvm/Support/ErrorHandling.h"
2324
#include "llvm/Support/raw_ostream.h"
2425

@@ -154,6 +155,7 @@ MLIRGenerator::MLIRGenerator(StringRef outputOpKindStr, StringRef kernelStr,
154155
llvm::StringSwitch<std::optional<QuantizationType>>(quantizationTypeStr)
155156
.CaseLower("quantize", QuantizationType::Quant)
156157
.CaseLower("dequantize", QuantizationType::Dequant)
158+
.CaseLower("testquant", QuantizationType::QuantDequant)
157159
.Default(QuantizationType::None);
158160
quantType = *optQuantType;
159161

@@ -210,20 +212,68 @@ void MLIRGenerator::getKernelTypes(KernelArgs &args) {
210212
if (quantType == QuantizationType::Dequant)
211213
arg.weightScale.type = getShape({outputSize}, WEIGHT_SCALE);
212214
arg.bias.type = getShape({outputSize}, PACK_OUTPUT);
213-
arg.output.type = getShape({batch, outputSize}, PACK_OUTPUT);
215+
216+
// For QuantDequant, such as F32->i8->F32, we need an intermediate type to
217+
// hold the quantized value.
218+
if (quantType == QuantizationType::QuantDequant) {
219+
arg.intermediate.type = getShape({batch, outputSize}, PACK_INTERMEDIATE);
220+
arg.output.type = getShape({batch, outputSize}, PACK_INPUT);
221+
} else {
222+
arg.output.type = getShape({batch, outputSize}, PACK_OUTPUT);
223+
}
214224
args.push_back(arg);
215225

216226
// Update next input type with the output type of this layer
217227
currentType = arg.output.type;
218228
}
219229
}
220230

231+
// Creates a quantize op around the gemm output and subsequently dequantizes
// it, returning the dequantized value.
232+
// This is mainly to validate the quantization scheme.
//   args:  layer arguments, forwarded unchanged to quantizeGemm.
//   input: value to quantize; createLayer passes the result of lowerMatmul.
233+
Value MLIRGenerator::testQuantDequant(LayerArgs &args, Value input) {
234+
// computeScalingFactor returns two broadcast factors: [0] is exp2 of the
// negated unbiased exponent (scale) and [1] is exp2 of the unbiased exponent
// (rescale), so multiplying by [1] undoes the multiplication by [0].
SmallVector<Value> scalingFactors = computeScalingFactor(input);
235+
// Quantize: quantizeGemm multiplies by the scale factor and, in QuantDequant
// mode, produces a tensor of args.intermediate.type.
Value chain = quantizeGemm(args, input, scalingFactors[0]);
236+
Value reScaleFactor = scalingFactors[1];
237+
Type rescaleType = reScaleFactor.getType();
238+
// Empty destination tensor with the rescale factor's type; it is used as the
// output operand of both the cast and the multiplication below.
auto castedOutput =
239+
builder.create<tensor::EmptyOp>(loc, rescaleType, ValueRange{});
240+
// Element-wise signed-integer-to-float cast of the quantized values.
Value castedVal =
241+
builder
242+
.create<linalg::GenericOp>(
243+
loc, rescaleType, ValueRange{chain}, ValueRange{castedOutput},
244+
ArrayRef<AffineMap>{getMap(chain, MAP_PARALLEL),
245+
getMap(castedOutput, MAP_PARALLEL)},
246+
getIterators(MAP_PARALLEL),
247+
[&](OpBuilder &nestedBuilder, Location nestedLoc,
248+
ValueRange blockArgs) {
249+
auto arg0 = blockArgs[0];
250+
// NOTE(review): the body uses the outer `loc` instead of `nestedLoc`, and
// the cast target is hard-coded to dataTypes[2]; presumably that equals
// the element type of rescaleType -- confirm.
auto casted = nestedBuilder.create<arith::SIToFPOp>(
251+
loc, dataTypes[2], arg0);
252+
nestedBuilder.create<linalg::YieldOp>(loc, ValueRange{casted});
253+
})
254+
.getResult(0);
255+
// Dequantize: multiply the casted values by the rescale factor.
castedVal = builder
256+
.create<linalg::MulOp>(loc, TypeRange{castedOutput.getType()},
257+
ValueRange{castedVal, reScaleFactor},
258+
ValueRange{castedOutput})
259+
.getResult(0);
260+
return castedVal;
261+
}
262+
221263
Value MLIRGenerator::createLayer(LayerArgs &args, bool hasMixedType) {
222264
OpBuilder::InsertionGuard guard(builder);
223265

224266
Value chain;
225267
chain = lowerMatmul(args, hasMixedType);
226268

269+
if (quantType == QuantizationType::QuantDequant)
270+
return testQuantDequant(args, chain);
271+
272+
if (quantType == QuantizationType::Quant) {
273+
SmallVector<Value> scalingFactors = computeScalingFactor(chain);
274+
chain = quantizeGemm(args, chain, scalingFactors[0]);
275+
}
276+
227277
if (quantType == QuantizationType::Dequant)
228278
chain = dequantizeGemm(args, chain);
229279

@@ -236,9 +286,6 @@ Value MLIRGenerator::createLayer(LayerArgs &args, bool hasMixedType) {
236286
chain = lowerNamedRelu(chain, args.output.value);
237287
}
238288

239-
if (quantType == QuantizationType::Quant)
240-
chain = quantizeGemm(args, chain);
241-
242289
// Last layer may output softmax
243290
if (args.index == layers.size() - 1) {
244291
if (outputOpKind == OutputOpKind::Generic) {
@@ -577,8 +624,7 @@ Value MLIRGenerator::lowerContract(Value input, Value weight, Value output) {
577624
return contract;
578625
}
579626

580-
Value MLIRGenerator::computeScalingFactor(MLIRContext *ctx, Value input,
581-
Value scale) {
627+
SmallVector<Value> MLIRGenerator::computeScalingFactor(Value input) {
582628
auto inputType = cast<ShapedType>(input.getType());
583629
assert(inputType.getRank() == 2 && "Input must be a 2D tensor");
584630

@@ -617,93 +663,116 @@ Value MLIRGenerator::computeScalingFactor(MLIRContext *ctx, Value input,
617663
// Compute the scaling factors (2^(-exponent)) from the absolute maximum
618664
// values.
619665
Value zeroVal = builder.create<arith::ConstantIntOp>(loc, 0, 32);
620-
Value zeroFloat =
621-
builder.create<arith::ConstantOp>(loc, builder.getF32FloatAttr(0.0f));
622-
Value channleScale =
666+
667+
// Create two output tensors for the two results
668+
context.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
669+
Value channelScale =
623670
builder.create<tensor::EmptyOp>(loc, reductionType, ValueRange{});
624-
Value filledchannleScale =
625-
builder.create<linalg::FillOp>(loc, zeroFloat, channleScale).getResult(0);
626-
Value frExp =
627-
builder
628-
.create<linalg::GenericOp>(
629-
loc, reductionType, ValueRange{absMax},
630-
ValueRange{filledchannleScale},
631-
ArrayRef<AffineMap>{getMap(absMax, MAP_PARALLEL),
632-
getMap(filledchannleScale, MAP_PARALLEL)},
633-
ArrayRef<utils::IteratorType>{utils::IteratorType::parallel},
634-
[&](OpBuilder &nestedBuilder, Location nestedLoc,
635-
ValueRange args) {
636-
Value frexpResult = LLVM::FractionExpOp::create(
637-
nestedBuilder, nestedLoc,
638-
LLVM::LLVMStructType::getLiteral(
639-
ctx, ArrayRef<Type>{elementType, builder.getI32Type()}),
640-
ValueRange{args[0]});
641-
Value exponent = LLVM::ExtractValueOp::create(
642-
nestedBuilder, nestedLoc,
643-
builder.getI32Type(), frexpResult, 1)
644-
.getResult();
645-
Value unbiased = nestedBuilder.create<arith::SubIOp>(
646-
nestedLoc, exponent,
647-
builder.create<arith::ConstantOp>(
648-
nestedLoc, builder.getI32IntegerAttr(7)));
649-
Value negExponent = nestedBuilder.create<arith::SubIOp>(
650-
nestedLoc, zeroVal, unbiased);
651-
auto tchannleScale =
652-
nestedBuilder
653-
.create<math::Exp2Op>(
654-
nestedLoc, nestedBuilder.create<arith::SIToFPOp>(
655-
nestedLoc, elementType, negExponent))
656-
->getResult(0);
657-
nestedBuilder.create<linalg::YieldOp>(nestedLoc, tchannleScale);
658-
})
659-
.getResult(0);
671+
Value channelReScale =
672+
builder.create<tensor::EmptyOp>(loc, reductionType, ValueRange{});
673+
674+
auto frExp = builder.create<linalg::GenericOp>(
675+
loc,
676+
TypeRange{reductionType, reductionType}, // Specify multiple result types
677+
ValueRange{absMax}, ValueRange{channelScale, channelReScale},
678+
ArrayRef<AffineMap>{getMap(absMax, MAP_PARALLEL),
679+
getMap(channelScale, MAP_PARALLEL),
680+
getMap(channelReScale, MAP_PARALLEL)},
681+
ArrayRef<utils::IteratorType>{utils::IteratorType::parallel},
682+
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
683+
Value frexpResult = LLVM::FractionExpOp::create(
684+
nestedBuilder, nestedLoc,
685+
LLVM::LLVMStructType::getLiteral(
686+
&context, ArrayRef<Type>{elementType, builder.getI32Type()}),
687+
ValueRange{args[0]});
688+
Value exponent =
689+
LLVM::ExtractValueOp::create(nestedBuilder, nestedLoc,
690+
builder.getI32Type(), frexpResult, 1)
691+
.getResult();
692+
Value unbiased = nestedBuilder.create<arith::SubIOp>(
693+
nestedLoc, exponent,
694+
builder.create<arith::ConstantOp>(nestedLoc,
695+
builder.getI32IntegerAttr(7)));
696+
Value negExponent =
697+
nestedBuilder.create<arith::SubIOp>(nestedLoc, zeroVal, unbiased);
698+
auto tchannleReScale =
699+
nestedBuilder
700+
.create<math::Exp2Op>(nestedLoc,
701+
nestedBuilder.create<arith::SIToFPOp>(
702+
nestedLoc, elementType, unbiased))
703+
->getResult(0);
704+
auto tchannleScale =
705+
nestedBuilder
706+
.create<math::Exp2Op>(nestedLoc,
707+
nestedBuilder.create<arith::SIToFPOp>(
708+
nestedLoc, elementType, negExponent))
709+
->getResult(0);
710+
nestedBuilder.create<linalg::YieldOp>(
711+
nestedLoc, ValueRange{tchannleScale, tchannleReScale});
712+
});
713+
714+
SmallVector<Value> frExpVec;
715+
frExpVec.push_back(frExp.getResults()[0]);
716+
frExpVec.push_back(frExp.getResults()[1]);
660717

718+
SmallVector<Value> scalingFactors;
719+
Value scalingFactor =
720+
builder.create<tensor::EmptyOp>(loc, inputType, ValueRange{});
661721
Value filledTensor =
662-
builder.create<linalg::FillOp>(loc, initValue, scale).getResult(0);
722+
builder.create<linalg::FillOp>(loc, initValue, scalingFactor)
723+
.getResult(0);
663724
// Broadcast to match output shape
664725
auto broadcastScaleRes =
665726
builder
666-
.create<linalg::BroadcastOp>(loc, frExp, filledTensor,
727+
.create<linalg::BroadcastOp>(loc, frExpVec[0], filledTensor,
728+
ArrayRef<int64_t>{0})
729+
->getResult(0);
730+
scalingFactors.push_back(broadcastScaleRes);
731+
732+
broadcastScaleRes =
733+
builder
734+
.create<linalg::BroadcastOp>(loc, frExpVec[1], filledTensor,
667735
ArrayRef<int64_t>{0})
668736
->getResult(0);
669-
return broadcastScaleRes;
737+
scalingFactors.push_back(broadcastScaleRes);
738+
739+
return scalingFactors;
670740
}
671741

672-
Value MLIRGenerator::quantizeGemm(LayerArgs &args, Value chain) {
742+
Value MLIRGenerator::quantizeGemm(LayerArgs &args, Value chain,
743+
Value scaleFactor) {
673744
Value input = args.input.value;
674745
Value weight = args.weight.value;
675-
Value output = args.output.value;
746+
Type outputType = quantType == QuantizationType::QuantDequant
747+
? args.intermediate.type
748+
: args.output.type;
676749

677750
auto inputShapedTy = cast<ShapedType>(input.getType());
678-
auto outputShapedTy = cast<ShapedType>(output.getType());
751+
auto outputShapedTy = cast<ShapedType>(outputType);
679752
auto shape = outputShapedTy.getShape();
680753
// Create an output type for the quantized output using shape and input element
681754
// type.
682755
auto contractOutputTy =
683756
RankedTensorType::get(shape, inputShapedTy.getElementType());
684-
context.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
757+
758+
auto castedOutput =
759+
builder.create<tensor::EmptyOp>(loc, outputShapedTy, ValueRange{});
685760
SmallVector<Attribute> maps;
686761
maps.push_back(AffineMapAttr::get(getMap(input, MAP_MATMUL_INPUT)));
687762
maps.push_back(AffineMapAttr::get(getMap(weight, MAP_MATMUL_WEIGHT)));
688-
maps.push_back(AffineMapAttr::get(getMap(output, MAP_MATMUL_OUTPUT)));
763+
maps.push_back(AffineMapAttr::get(getMap(castedOutput, MAP_MATMUL_OUTPUT)));
689764
auto dquantVal = getZeroInitTensor(contractOutputTy);
690-
Value scalingFactor =
691-
builder.create<tensor::EmptyOp>(loc, contractOutputTy, ValueRange{});
692-
scalingFactor = computeScalingFactor(&context, chain, scalingFactor);
693765

694766
auto dquantRes = builder
695767
.create<linalg::MulOp>(loc, chain.getType(),
696-
ValueRange{chain, scalingFactor},
768+
ValueRange{chain, scaleFactor},
697769
ValueRange{dquantVal})
698770
.getResult(0);
699771

700-
// Convert to integer type
701-
auto castedOutput =
702-
builder.create<tensor::EmptyOp>(loc, output.getType(), ValueRange{});
703772
dquantRes =
704773
builder
705774
.create<linalg::GenericOp>(
706-
loc, output.getType(), ValueRange{dquantRes},
775+
loc, outputShapedTy, ValueRange{dquantRes},
707776
ValueRange{castedOutput},
708777
ArrayRef<AffineMap>{getMap(dquantRes, MAP_PARALLEL),
709778
getMap(castedOutput, MAP_PARALLEL)},
@@ -1023,6 +1092,10 @@ TensorType MLIRGenerator::getShape(ArrayRef<int64_t> dims, PackingType type) {
10231092
return RankedTensorType::get(dims, dataTypes[2]);
10241093
} else if (type == PACK_OUTPUT) {
10251094
return RankedTensorType::get(dims, dataTypes[1]);
1095+
} else if (type == PACK_INPUT) {
1096+
return RankedTensorType::get(dims, dataTypes[0]);
1097+
} else if (type == PACK_INTERMEDIATE) {
1098+
return RankedTensorType::get(dims, dataTypes[1]);
10261099
}
10271100
}
10281101
// Unpacked type, just return 2D tensor
@@ -1070,6 +1143,8 @@ TensorType MLIRGenerator::getShape(ArrayRef<int64_t> dims, PackingType type) {
10701143
return RankedTensorType::get({n}, dataTypes[2]);
10711144
case WEIGHT_SCALE:
10721145
return RankedTensorType::get({k}, dataTypes[2]);
1146+
case PACK_INTERMEDIATE:
1147+
llvm_unreachable("Unknown intermediate packing type");
10731148
}
10741149

10751150
llvm_unreachable("Unknown packing type");

0 commit comments

Comments
 (0)