@@ -5669,6 +5669,52 @@ TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Weight) {
5669
5669
test_case (true );
5670
5670
}
5671
5671
5672
+ TEST (QDQTransformerTests, WeightBiasQuantization_Gemm_HandleNegativeDqAxis) {
5673
+ auto test_case = [](bool use_contrib_qdq) {
5674
+ auto build_test_case = [&](ModelTestBuilder& builder) {
5675
+ NodeArg* input_arg =
5676
+ builder.MakeInput <uint8_t >({2 , 16 }, std::numeric_limits<uint8_t >::min (), std::numeric_limits<uint8_t >::max ());
5677
+ NodeArg* weight_arg = builder.MakeInitializer <uint8_t >({16 , 16 }, std::numeric_limits<uint8_t >::min (),
5678
+ std::numeric_limits<uint8_t >::max ());
5679
+ NodeArg* bias_arg = builder.MakeInitializer <float >({16 }, -0 .1f , 0 .1f );
5680
+
5681
+ NodeArg* input_dq_arg = builder.MakeIntermediate ();
5682
+ NodeArg* weight_dq_arg = builder.MakeIntermediate ();
5683
+ NodeArg* gemm_dq_arg = builder.MakeIntermediate ();
5684
+ NodeArg* output_arg = builder.MakeOutput ();
5685
+
5686
+ builder.AddDequantizeLinearNode <uint8_t >(input_arg, 0 .001f , static_cast <uint8_t >(0 ), input_dq_arg, use_contrib_qdq);
5687
+
5688
+ // Per-channel quantized weight with negative axis as DQ attribute
5689
+ std::vector<float > scales = std::vector<float >(16 , 0 .05f );
5690
+ std::vector<uint8_t > zp = std::vector<uint8_t >(16 , static_cast <uint8_t >(0 ));
5691
+ auto & dq_node = builder.AddDequantizeLinearNode <uint8_t >(weight_arg, scales, zp, weight_dq_arg, nullptr , use_contrib_qdq);
5692
+ dq_node.AddAttribute (" axis" , static_cast <int64_t >(-1 ));
5693
+
5694
+ builder.AddNode (" Gemm" , {input_dq_arg, weight_dq_arg, bias_arg}, {gemm_dq_arg});
5695
+ builder.AddQuantizeLinearNode <uint8_t >(gemm_dq_arg, 0 .144f , static_cast <uint8_t >(69 ), output_arg, use_contrib_qdq);
5696
+ };
5697
+
5698
+ auto check_transformed_graph = [](InferenceSessionWrapper& session) {
5699
+ auto op_to_count = CountOpsInGraph (session.GetGraph ());
5700
+ EXPECT_EQ (op_to_count[" QuantizeLinear" ] + op_to_count[" com.microsoft.QuantizeLinear" ], 1 );
5701
+ EXPECT_EQ (op_to_count[" DequantizeLinear" ] + op_to_count[" com.microsoft.DequantizeLinear" ], 2 + 1 );
5702
+ };
5703
+
5704
+ TransformerTester (build_test_case,
5705
+ check_transformed_graph,
5706
+ TransformerLevel::Default,
5707
+ TransformerLevel::Level1,
5708
+ /* opset_version=*/ 20 ,
5709
+ /* per_sample_tolerance=*/ 0.01 ,
5710
+ /* relative_per_sample_tolerance=*/ 0.01 ,
5711
+ /* transformer=*/ std::make_unique<WeightBiasQuantization>());
5712
+ };
5713
+
5714
+ test_case (false );
5715
+ test_case (true );
5716
+ }
5717
+
5672
5718
TEST (QDQTransformerTests, WeightBiasQuantization_Gemm_Weight_Bias) {
5673
5719
auto test_case = [](bool use_contrib_qdq) {
5674
5720
auto build_test_case = [&](ModelTestBuilder& builder) {
0 commit comments