19
19
20
20
#include " MLIRGen.h"
21
21
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
22
+ // #include "mlir/IR/Value.h"
22
23
#include " llvm/Support/ErrorHandling.h"
23
24
#include " llvm/Support/raw_ostream.h"
24
25
@@ -154,6 +155,7 @@ MLIRGenerator::MLIRGenerator(StringRef outputOpKindStr, StringRef kernelStr,
154
155
llvm::StringSwitch<std::optional<QuantizationType>>(quantizationTypeStr)
155
156
.CaseLower (" quantize" , QuantizationType::Quant)
156
157
.CaseLower (" dequantize" , QuantizationType::Dequant)
158
+ .CaseLower (" testquant" , QuantizationType::QuantDequant)
157
159
.Default (QuantizationType::None);
158
160
quantType = *optQuantType;
159
161
@@ -210,20 +212,68 @@ void MLIRGenerator::getKernelTypes(KernelArgs &args) {
210
212
if (quantType == QuantizationType::Dequant)
211
213
arg.weightScale .type = getShape ({outputSize}, WEIGHT_SCALE);
212
214
arg.bias .type = getShape ({outputSize}, PACK_OUTPUT);
213
- arg.output .type = getShape ({batch, outputSize}, PACK_OUTPUT);
215
+
216
+ // For QuantDequant, such as F32->i8->F32, we need an intermediate type to
217
+ // hold the quantized value.
218
+ if (quantType == QuantizationType::QuantDequant) {
219
+ arg.intermediate .type = getShape ({batch, outputSize}, PACK_INTERMEDIATE);
220
+ arg.output .type = getShape ({batch, outputSize}, PACK_INPUT);
221
+ } else {
222
+ arg.output .type = getShape ({batch, outputSize}, PACK_OUTPUT);
223
+ }
214
224
args.push_back (arg);
215
225
216
226
// Update next input type with the output type of this layer
217
227
currentType = arg.output .type ;
218
228
}
219
229
}
220
230
231
+ // Creates a quantize op around the gemm output and subsequently dequantize it.
232
+ // This is mainly to validate the quantization scheme.
233
+ Value MLIRGenerator::testQuantDequant (LayerArgs &args, Value input) {
234
+ SmallVector<Value> scalingFactors = computeScalingFactor (input);
235
+ Value chain = quantizeGemm (args, input, scalingFactors[0 ]);
236
+ Value reScaleFactor = scalingFactors[1 ];
237
+ Type rescaleType = reScaleFactor.getType ();
238
+ auto castedOutput =
239
+ builder.create <tensor::EmptyOp>(loc, rescaleType, ValueRange{});
240
+ Value castedVal =
241
+ builder
242
+ .create <linalg::GenericOp>(
243
+ loc, rescaleType, ValueRange{chain}, ValueRange{castedOutput},
244
+ ArrayRef<AffineMap>{getMap (chain, MAP_PARALLEL),
245
+ getMap (castedOutput, MAP_PARALLEL)},
246
+ getIterators (MAP_PARALLEL),
247
+ [&](OpBuilder &nestedBuilder, Location nestedLoc,
248
+ ValueRange blockArgs) {
249
+ auto arg0 = blockArgs[0 ];
250
+ auto casted = nestedBuilder.create <arith::SIToFPOp>(
251
+ loc, dataTypes[2 ], arg0);
252
+ nestedBuilder.create <linalg::YieldOp>(loc, ValueRange{casted});
253
+ })
254
+ .getResult (0 );
255
+ castedVal = builder
256
+ .create <linalg::MulOp>(loc, TypeRange{castedOutput.getType ()},
257
+ ValueRange{castedVal, reScaleFactor},
258
+ ValueRange{castedOutput})
259
+ .getResult (0 );
260
+ return castedVal;
261
+ }
262
+
221
263
Value MLIRGenerator::createLayer (LayerArgs &args, bool hasMixedType) {
222
264
OpBuilder::InsertionGuard guard (builder);
223
265
224
266
Value chain;
225
267
chain = lowerMatmul (args, hasMixedType);
226
268
269
+ if (quantType == QuantizationType::QuantDequant)
270
+ return testQuantDequant (args, chain);
271
+
272
+ if (quantType == QuantizationType::Quant) {
273
+ SmallVector<Value> scalingFactors = computeScalingFactor (chain);
274
+ chain = quantizeGemm (args, chain, scalingFactors[0 ]);
275
+ }
276
+
227
277
if (quantType == QuantizationType::Dequant)
228
278
chain = dequantizeGemm (args, chain);
229
279
@@ -236,9 +286,6 @@ Value MLIRGenerator::createLayer(LayerArgs &args, bool hasMixedType) {
236
286
chain = lowerNamedRelu (chain, args.output .value );
237
287
}
238
288
239
- if (quantType == QuantizationType::Quant)
240
- chain = quantizeGemm (args, chain);
241
-
242
289
// Last layer may output softmax
243
290
if (args.index == layers.size () - 1 ) {
244
291
if (outputOpKind == OutputOpKind::Generic) {
@@ -577,8 +624,7 @@ Value MLIRGenerator::lowerContract(Value input, Value weight, Value output) {
577
624
return contract;
578
625
}
579
626
580
- Value MLIRGenerator::computeScalingFactor (MLIRContext *ctx, Value input,
581
- Value scale) {
627
+ SmallVector<Value> MLIRGenerator::computeScalingFactor (Value input) {
582
628
auto inputType = cast<ShapedType>(input.getType ());
583
629
assert (inputType.getRank () == 2 && " Input must be a 2D tensor" );
584
630
@@ -617,93 +663,116 @@ Value MLIRGenerator::computeScalingFactor(MLIRContext *ctx, Value input,
617
663
// Compute the scaling factors (2^(-exponent)) from the absolute maximum
618
664
// values.
619
665
Value zeroVal = builder.create <arith::ConstantIntOp>(loc, 0 , 32 );
620
- Value zeroFloat =
621
- builder.create <arith::ConstantOp>(loc, builder.getF32FloatAttr (0 .0f ));
622
- Value channleScale =
666
+
667
+ // Create two output tensors for the two results
668
+ context.getOrLoadDialect <mlir::LLVM::LLVMDialect>();
669
+ Value channelScale =
623
670
builder.create <tensor::EmptyOp>(loc, reductionType, ValueRange{});
624
- Value filledchannleScale =
625
- builder.create <linalg::FillOp>(loc, zeroFloat, channleScale).getResult (0 );
626
- Value frExp =
627
- builder
628
- .create <linalg::GenericOp>(
629
- loc, reductionType, ValueRange{absMax},
630
- ValueRange{filledchannleScale},
631
- ArrayRef<AffineMap>{getMap (absMax, MAP_PARALLEL),
632
- getMap (filledchannleScale, MAP_PARALLEL)},
633
- ArrayRef<utils::IteratorType>{utils::IteratorType::parallel},
634
- [&](OpBuilder &nestedBuilder, Location nestedLoc,
635
- ValueRange args) {
636
- Value frexpResult = LLVM::FractionExpOp::create (
637
- nestedBuilder, nestedLoc,
638
- LLVM::LLVMStructType::getLiteral (
639
- ctx, ArrayRef<Type>{elementType, builder.getI32Type ()}),
640
- ValueRange{args[0 ]});
641
- Value exponent = LLVM::ExtractValueOp::create (
642
- nestedBuilder, nestedLoc,
643
- builder.getI32Type (), frexpResult, 1 )
644
- .getResult ();
645
- Value unbiased = nestedBuilder.create <arith::SubIOp>(
646
- nestedLoc, exponent,
647
- builder.create <arith::ConstantOp>(
648
- nestedLoc, builder.getI32IntegerAttr (7 )));
649
- Value negExponent = nestedBuilder.create <arith::SubIOp>(
650
- nestedLoc, zeroVal, unbiased);
651
- auto tchannleScale =
652
- nestedBuilder
653
- .create <math::Exp2Op>(
654
- nestedLoc, nestedBuilder.create <arith::SIToFPOp>(
655
- nestedLoc, elementType, negExponent))
656
- ->getResult (0 );
657
- nestedBuilder.create <linalg::YieldOp>(nestedLoc, tchannleScale);
658
- })
659
- .getResult (0 );
671
+ Value channelReScale =
672
+ builder.create <tensor::EmptyOp>(loc, reductionType, ValueRange{});
673
+
674
+ auto frExp = builder.create <linalg::GenericOp>(
675
+ loc,
676
+ TypeRange{reductionType, reductionType}, // Specify multiple result types
677
+ ValueRange{absMax}, ValueRange{channelScale, channelReScale},
678
+ ArrayRef<AffineMap>{getMap (absMax, MAP_PARALLEL),
679
+ getMap (channelScale, MAP_PARALLEL),
680
+ getMap (channelReScale, MAP_PARALLEL)},
681
+ ArrayRef<utils::IteratorType>{utils::IteratorType::parallel},
682
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
683
+ Value frexpResult = LLVM::FractionExpOp::create (
684
+ nestedBuilder, nestedLoc,
685
+ LLVM::LLVMStructType::getLiteral (
686
+ &context, ArrayRef<Type>{elementType, builder.getI32Type ()}),
687
+ ValueRange{args[0 ]});
688
+ Value exponent =
689
+ LLVM::ExtractValueOp::create (nestedBuilder, nestedLoc,
690
+ builder.getI32Type (), frexpResult, 1 )
691
+ .getResult ();
692
+ Value unbiased = nestedBuilder.create <arith::SubIOp>(
693
+ nestedLoc, exponent,
694
+ builder.create <arith::ConstantOp>(nestedLoc,
695
+ builder.getI32IntegerAttr (7 )));
696
+ Value negExponent =
697
+ nestedBuilder.create <arith::SubIOp>(nestedLoc, zeroVal, unbiased);
698
+ auto tchannleReScale =
699
+ nestedBuilder
700
+ .create <math::Exp2Op>(nestedLoc,
701
+ nestedBuilder.create <arith::SIToFPOp>(
702
+ nestedLoc, elementType, unbiased))
703
+ ->getResult (0 );
704
+ auto tchannleScale =
705
+ nestedBuilder
706
+ .create <math::Exp2Op>(nestedLoc,
707
+ nestedBuilder.create <arith::SIToFPOp>(
708
+ nestedLoc, elementType, negExponent))
709
+ ->getResult (0 );
710
+ nestedBuilder.create <linalg::YieldOp>(
711
+ nestedLoc, ValueRange{tchannleScale, tchannleReScale});
712
+ });
713
+
714
+ SmallVector<Value> frExpVec;
715
+ frExpVec.push_back (frExp.getResults ()[0 ]);
716
+ frExpVec.push_back (frExp.getResults ()[1 ]);
660
717
718
+ SmallVector<Value> scalingFactors;
719
+ Value scalingFactor =
720
+ builder.create <tensor::EmptyOp>(loc, inputType, ValueRange{});
661
721
Value filledTensor =
662
- builder.create <linalg::FillOp>(loc, initValue, scale).getResult (0 );
722
+ builder.create <linalg::FillOp>(loc, initValue, scalingFactor)
723
+ .getResult (0 );
663
724
// Broadcast to match output shape
664
725
auto broadcastScaleRes =
665
726
builder
666
- .create <linalg::BroadcastOp>(loc, frExp, filledTensor,
727
+ .create <linalg::BroadcastOp>(loc, frExpVec[0 ], filledTensor,
728
+ ArrayRef<int64_t >{0 })
729
+ ->getResult (0 );
730
+ scalingFactors.push_back (broadcastScaleRes);
731
+
732
+ broadcastScaleRes =
733
+ builder
734
+ .create <linalg::BroadcastOp>(loc, frExpVec[1 ], filledTensor,
667
735
ArrayRef<int64_t >{0 })
668
736
->getResult (0 );
669
- return broadcastScaleRes;
737
+ scalingFactors.push_back (broadcastScaleRes);
738
+
739
+ return scalingFactors;
670
740
}
671
741
672
- Value MLIRGenerator::quantizeGemm (LayerArgs &args, Value chain) {
742
+ Value MLIRGenerator::quantizeGemm (LayerArgs &args, Value chain,
743
+ Value scaleFactor) {
673
744
Value input = args.input .value ;
674
745
Value weight = args.weight .value ;
675
- Value output = args.output .value ;
746
+ Type outputType = quantType == QuantizationType::QuantDequant
747
+ ? args.intermediate .type
748
+ : args.output .type ;
676
749
677
750
auto inputShapedTy = cast<ShapedType>(input.getType ());
678
- auto outputShapedTy = cast<ShapedType>(output. getType () );
751
+ auto outputShapedTy = cast<ShapedType>(outputType );
679
752
auto shape = outputShapedTy.getShape ();
680
753
// Create a output type for the quantized output using shape and input element
681
754
// type.
682
755
auto contractOutputTy =
683
756
RankedTensorType::get (shape, inputShapedTy.getElementType ());
684
- context.getOrLoadDialect <mlir::LLVM::LLVMDialect>();
757
+
758
+ auto castedOutput =
759
+ builder.create <tensor::EmptyOp>(loc, outputShapedTy, ValueRange{});
685
760
SmallVector<Attribute> maps;
686
761
maps.push_back (AffineMapAttr::get (getMap (input, MAP_MATMUL_INPUT)));
687
762
maps.push_back (AffineMapAttr::get (getMap (weight, MAP_MATMUL_WEIGHT)));
688
- maps.push_back (AffineMapAttr::get (getMap (output , MAP_MATMUL_OUTPUT)));
763
+ maps.push_back (AffineMapAttr::get (getMap (castedOutput , MAP_MATMUL_OUTPUT)));
689
764
auto dquantVal = getZeroInitTensor (contractOutputTy);
690
- Value scalingFactor =
691
- builder.create <tensor::EmptyOp>(loc, contractOutputTy, ValueRange{});
692
- scalingFactor = computeScalingFactor (&context, chain, scalingFactor);
693
765
694
766
auto dquantRes = builder
695
767
.create <linalg::MulOp>(loc, chain.getType (),
696
- ValueRange{chain, scalingFactor },
768
+ ValueRange{chain, scaleFactor },
697
769
ValueRange{dquantVal})
698
770
.getResult (0 );
699
771
700
- // Convert to integer type
701
- auto castedOutput =
702
- builder.create <tensor::EmptyOp>(loc, output.getType (), ValueRange{});
703
772
dquantRes =
704
773
builder
705
774
.create <linalg::GenericOp>(
706
- loc, output. getType () , ValueRange{dquantRes},
775
+ loc, outputShapedTy , ValueRange{dquantRes},
707
776
ValueRange{castedOutput},
708
777
ArrayRef<AffineMap>{getMap (dquantRes, MAP_PARALLEL),
709
778
getMap (castedOutput, MAP_PARALLEL)},
@@ -1023,6 +1092,10 @@ TensorType MLIRGenerator::getShape(ArrayRef<int64_t> dims, PackingType type) {
1023
1092
return RankedTensorType::get (dims, dataTypes[2 ]);
1024
1093
} else if (type == PACK_OUTPUT) {
1025
1094
return RankedTensorType::get (dims, dataTypes[1 ]);
1095
+ } else if (type == PACK_INPUT) {
1096
+ return RankedTensorType::get (dims, dataTypes[0 ]);
1097
+ } else if (type == PACK_INTERMEDIATE) {
1098
+ return RankedTensorType::get (dims, dataTypes[1 ]);
1026
1099
}
1027
1100
}
1028
1101
// Unpacked type, just return 2D tensor
@@ -1070,6 +1143,8 @@ TensorType MLIRGenerator::getShape(ArrayRef<int64_t> dims, PackingType type) {
1070
1143
return RankedTensorType::get ({n}, dataTypes[2 ]);
1071
1144
case WEIGHT_SCALE:
1072
1145
return RankedTensorType::get ({k}, dataTypes[2 ]);
1146
+ case PACK_INTERMEDIATE:
1147
+ llvm_unreachable (" Unknown intermediate packing type" );
1073
1148
}
1074
1149
1075
1150
llvm_unreachable (" Unknown packing type" );
0 commit comments