Skip to content

Commit 3ec8829

Browse files
qti-yuduoedgchen1
and Edward Chen authored
Normalize to positive axis for comparison on GEMM (#26152)
GEMM has 2D input, when given `axis=-1` in DequantizeLinear node attributes. The check `axis == expected_axis` would fail. However, it sematically the same (-1 and 1 for 2D GEMM's input). Normalize the check so we can still perform the quantization for this case. --------- Co-authored-by: Edward Chen <[email protected]>
1 parent 2e8a45a commit 3ec8829

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "core/optimizer/qdq_transformer/weight_bias_quantization.h"
55

66
#include "core/common/common.h"
7+
#include "core/providers/common.h"
78
#include "core/util/qmath.h"
89
#include "core/graph/graph_utils.h"
910
#include "core/graph/graph_viewer.h"
@@ -128,6 +129,14 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph
128129
int64_t axis = 1;
129130
if (auto axis_iter = dq_attrs.find("axis"); axis_iter != dq_attrs.end()) {
130131
axis = axis_iter->second.i();
132+
const ONNX_NAMESPACE::TensorShapeProto* weight_shape = weight_arg->Shape();
133+
if (!weight_shape && dq_1->InputDefs()[0]) {
134+
weight_shape = dq_1->InputDefs()[0]->Shape();
135+
}
136+
if (axis < 0 && !weight_shape) {
137+
continue;
138+
}
139+
axis = HandleNegativeAxis(axis, weight_shape->dim_size());
131140
}
132141

133142
int64_t expected_axis = 0;

onnxruntime/test/optimizer/qdq_transformer_test.cc

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5669,6 +5669,52 @@ TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Weight) {
56695669
test_case(true);
56705670
}
56715671

5672+
// Verifies that WeightBiasQuantization still quantizes the Gemm bias when the weight's
// per-channel DequantizeLinear node uses a negative axis attribute (axis=-1 on a 2D
// weight is semantically identical to axis=1). Before the fix, the literal
// `axis == expected_axis` comparison rejected this case; the transformer must now
// normalize the axis and insert the bias QuantizeLinear/DequantizeLinear pair.
TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_HandleNegativeDqAxis) {
  auto test_case = [](bool use_contrib_qdq) {
    auto build_test_case = [&](ModelTestBuilder& builder) {
      // DQ(input) -> Gemm <- DQ(weight, per-channel, axis=-1); float bias; Q on output.
      NodeArg* input_arg =
          builder.MakeInput<uint8_t>({2, 16}, std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
      NodeArg* weight_arg = builder.MakeInitializer<uint8_t>({16, 16}, std::numeric_limits<uint8_t>::min(),
                                                             std::numeric_limits<uint8_t>::max());
      NodeArg* bias_arg = builder.MakeInitializer<float>({16}, -0.1f, 0.1f);

      NodeArg* input_dq_arg = builder.MakeIntermediate();
      NodeArg* weight_dq_arg = builder.MakeIntermediate();
      NodeArg* gemm_dq_arg = builder.MakeIntermediate();
      NodeArg* output_arg = builder.MakeOutput();

      builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.001f, static_cast<uint8_t>(0), input_dq_arg,
                                               use_contrib_qdq);

      // Per-channel quantized weight with a negative axis as the DQ attribute.
      // axis=-1 on the 2D {16, 16} weight resolves to axis=1 after normalization.
      std::vector<float> scales = std::vector<float>(16, 0.05f);
      std::vector<uint8_t> zp = std::vector<uint8_t>(16, static_cast<uint8_t>(0));
      auto& dq_node =
          builder.AddDequantizeLinearNode<uint8_t>(weight_arg, scales, zp, weight_dq_arg, nullptr, use_contrib_qdq);
      dq_node.AddAttribute("axis", static_cast<int64_t>(-1));

      builder.AddNode("Gemm", {input_dq_arg, weight_dq_arg, bias_arg}, {gemm_dq_arg});
      builder.AddQuantizeLinearNode<uint8_t>(gemm_dq_arg, 0.144f, static_cast<uint8_t>(69), output_arg,
                                             use_contrib_qdq);
    };

    auto check_transformed_graph = [](InferenceSessionWrapper& session) {
      auto op_to_count = CountOpsInGraph(session.GetGraph());
      // Expect the original Q on the output, plus the original 2 DQs and 1 new DQ
      // inserted for the now-quantized bias.
      EXPECT_EQ(op_to_count["QuantizeLinear"] + op_to_count["com.microsoft.QuantizeLinear"], 1);
      EXPECT_EQ(op_to_count["DequantizeLinear"] + op_to_count["com.microsoft.DequantizeLinear"], 2 + 1);
    };

    TransformerTester(build_test_case,
                      check_transformed_graph,
                      TransformerLevel::Default,
                      TransformerLevel::Level1,
                      /*opset_version=*/20,
                      /*per_sample_tolerance=*/0.01,
                      /*relative_per_sample_tolerance=*/0.01,
                      /*transformer=*/std::make_unique<WeightBiasQuantization>());
  };

  test_case(false);
  test_case(true);
}
5717+
56725718
TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Weight_Bias) {
56735719
auto test_case = [](bool use_contrib_qdq) {
56745720
auto build_test_case = [&](ModelTestBuilder& builder) {

0 commit comments

Comments
 (0)