Skip to content

Commit 24f6118

Browse files
adrianlizarragaguschmue
authored andcommitted
[Quant Tool] Update QDQ Pad, Slice, Softmax (#22676)
### Description Updates python quantization tool: - Ensures QDQ Pad has equal quantization parameters across input and output for certain Pad configurations. - Ensures QDQ Slice always has equal quantization parameters across input and output. - Fixes bug when Softmax is _excluded_ from quantization. ### Motivation and Context QDQ Pad and Slice have lower latency on QNN EP when their quantization parameters are equal.
1 parent 0035c71 commit 24f6118

File tree

7 files changed

+461
-2
lines changed

7 files changed

+461
-2
lines changed

onnxruntime/python/tools/quantization/base_quantizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,4 +554,6 @@ def adjust_tensor_ranges(self):
554554
self.tensors_range[node.input[0]] = td
555555
# Adjust Softmax to range from 0.0 to 1.0
556556
elif node.op_type == "Softmax":
557+
if not self.should_quantize_node(node):
558+
continue
557559
self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0))

onnxruntime/python/tools/quantization/operators/pad.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
# --------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
# --------------------------------------------------------------------------
5+
from __future__ import annotations
6+
7+
from typing import Any
8+
9+
import numpy as np
110
import onnx
211

312
from ..quant_utils import (
@@ -8,6 +17,7 @@
817
quantize_nparray,
918
)
1019
from .base_operator import QuantOperatorBase
20+
from .qdq_base_operator import QDQOperatorBase
1121

1222

1323
class QPad(QuantOperatorBase):
@@ -98,3 +108,65 @@ def quantize(self):
98108
node.input[0] = quantized_input_value.q_name
99109
node.output[0] = quantized_output_value.q_name
100110
self.quantizer.new_nodes += [node]
111+
112+
113+
class QDQPad(QDQOperatorBase):
114+
def __init__(self, onnx_quantizer, onnx_node):
115+
super().__init__(onnx_quantizer, onnx_node)
116+
117+
def _get_pad_const_val(self, attrs_dict: dict[str, Any]) -> np.ndarray | None:
118+
"""
119+
Returns the Pad's constant padding value. Returns `None` if the padding value is
120+
not constant (i.e., comes from a dynamic input).
121+
"""
122+
const_val = None
123+
onnx_tensor_type = self.quantizer.model.get_tensor_type(self.node.input[0])
124+
if onnx_tensor_type is None:
125+
return None
126+
127+
np_dtype = onnx.helper.tensor_dtype_to_np_dtype(onnx_tensor_type.elem_type)
128+
if self.quantizer.opset_version < 11:
129+
const_val = np.array(attrs_dict.get("value", 0), dtype=np_dtype)
130+
elif len(self.node.input) >= 3 and self.node.input[2]:
131+
const_val = self.quantizer.model.get_constant_value(self.node.input[2])
132+
else:
133+
const_val = np.array(0, dtype=np_dtype)
134+
135+
return const_val
136+
137+
def _should_quantize_output_same_as_input(self) -> bool:
138+
"""
139+
Returns true if Pad's output should use the same quantization parameters as input[0]
140+
"""
141+
attrs_dict = {}
142+
for attribute in self.node.attribute:
143+
kv = attribute_to_kwarg(attribute)
144+
attrs_dict.update(kv)
145+
146+
pad_mode = attrs_dict.get("mode", b"constant")
147+
if pad_mode in (b"reflect", b"edge", b"wrap"):
148+
# These modes pad the output with a value that already exists in the input.
149+
# So, we can quantize the output the same as the input.
150+
return True
151+
152+
# For 'constant' mode, if padding with 0, we can also quantize the output the same as the input
153+
# because our quantization floating-point range always includes 0.
154+
if pad_mode == b"constant":
155+
pad_val = self._get_pad_const_val(attrs_dict)
156+
if pad_val is not None and pad_val.dtype in (np.float32, np.float16):
157+
return float(pad_val.item()) == 0
158+
159+
return False
160+
161+
def quantize(self):
162+
assert self.node.op_type == "Pad"
163+
164+
for input_name in self.node.input:
165+
if input_name:
166+
self.quantizer.quantize_activation_tensor(input_name)
167+
168+
if not self.disable_qdq_for_node_output:
169+
if self._should_quantize_output_same_as_input():
170+
self.quantizer.quantize_output_same_as_input(self.node.output[0], self.node.input[0], self.node.name)
171+
else:
172+
self.quantizer.quantize_activation_tensor(self.node.output[0])

onnxruntime/python/tools/quantization/registry.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul
1515
from .operators.maxpool import QDQMaxPool, QMaxPool
1616
from .operators.norm import QDQNormalization
17-
from .operators.pad import QPad
17+
from .operators.pad import QDQPad, QPad
1818
from .operators.pooling import QLinearPool
1919
from .operators.qdq_base_operator import QDQOperatorBase
2020
from .operators.resize import QDQResize, QResize
@@ -76,6 +76,8 @@
7676
"Resize": QDQResize,
7777
"MaxPool": QDQMaxPool,
7878
"AveragePool": QDQDirect8BitOp,
79+
"Slice": QDQDirect8BitOp,
80+
"Pad": QDQPad,
7981
"MatMul": QDQMatMul,
8082
"Split": QDQSplit,
8183
"Gather": QDQGather,

onnxruntime/test/python/quantization/op_test_utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
from __future__ import annotations
7+
18
import uuid
29
from pathlib import Path
310

@@ -661,3 +668,29 @@ def generate_random_initializer(initializer_name, tensor_shape, tensor_dtype, me
661668
tensor = np.random.normal(mean, dev, tensor_shape).astype(tensor_dtype)
662669
init = onnx.numpy_helper.from_array(tensor, initializer_name)
663670
return init
671+
672+
673+
def get_tensor_consumers_and_producers(
674+
model: onnx.ModelProto,
675+
) -> tuple[dict[str, list[onnx.NodeProto]], dict[str, onnx.NodeProto]]:
676+
"""
677+
Returns a tuple containing the following python dictionaries:
678+
- consumers: maps a tensor name to the list of nodes that have that tensor as an input.
679+
- producers: maps a tensor name to the node that generates this tensor as an output.
680+
"""
681+
consumers: dict[str, list[onnx.NodeProto]] = {}
682+
producers: dict[str, onnx.NodeProto] = {}
683+
for node in model.graph.node:
684+
# Iterate through node's inputs to build the consumers dictionary.
685+
for input_name in node.input:
686+
if input_name:
687+
if input_name not in consumers:
688+
consumers[input_name] = []
689+
690+
consumers[input_name].append(node)
691+
692+
# Iterate through node's outputs to build the producers dictionary.
693+
for output_name in node.output:
694+
producers[output_name] = node
695+
696+
return (consumers, producers)

onnxruntime/test/python/quantization/test_op_pad.py

Lines changed: 164 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,23 @@
44
# Licensed under the MIT License. See License.txt in the project root for
55
# license information.
66
# --------------------------------------------------------------------------
7+
from __future__ import annotations
78

89
import itertools
10+
import os
11+
import tempfile
912
import unittest
1013

1114
import numpy as np
1215
import onnx
1316
from onnx import TensorProto, helper
14-
from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type
17+
from op_test_utils import (
18+
TestDataFeeds,
19+
check_model_correctness,
20+
check_op_type_count,
21+
check_qtype_by_node_type,
22+
get_tensor_consumers_and_producers,
23+
)
1524

1625
from onnxruntime.quantization import QuantFormat, QuantType, quantize_dynamic, quantize_static
1726

@@ -519,5 +528,159 @@ def test_pad_with_empty_string_input_name(self):
519528
self.assertNotEqual(name, "_quantized")
520529

521530

531+
class TestQDQPad(unittest.TestCase):
532+
@classmethod
533+
def setUpClass(cls):
534+
cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.pad_")
535+
536+
# Note: swap with the commented line if you want to see the models in local test dir.
537+
cls._tmp_dir_path = cls._tmp_model_dir.name
538+
# cls._tmp_dir_path = "."
539+
540+
@classmethod
541+
def tearDownClass(cls):
542+
cls._tmp_model_dir.cleanup()
543+
544+
def build_pad_model(
545+
self,
546+
mode: str,
547+
constant_value: float | None = None,
548+
opset: int = 21,
549+
float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT,
550+
) -> onnx.ModelProto:
551+
input_0 = onnx.helper.make_tensor_value_info("input_0", float_type, (3, 2))
552+
output_0 = onnx.helper.make_tensor_value_info("output_0", float_type, (3, 4))
553+
554+
initializers = []
555+
pad_input_names = ["input_0"]
556+
attrs = {"mode": mode}
557+
558+
pads_data = np.array([0, 2, 0, 0], dtype=np.int64) # Pad two vals at beginning of axis 1.
559+
if opset >= 11:
560+
initializers.append(onnx.numpy_helper.from_array(pads_data, "pads"))
561+
pad_input_names.append("pads")
562+
else:
563+
attrs["pads"] = pads_data.tolist()
564+
565+
if mode == "constant" and constant_value is not None:
566+
if opset >= 11:
567+
initializers.append(onnx.helper.make_tensor("constant_value", float_type, [], [constant_value]))
568+
pad_input_names.append("constant_value")
569+
else:
570+
attrs["value"] = float(constant_value)
571+
572+
pad_node = onnx.helper.make_node("Pad", pad_input_names, ["output_0"], name="Pad0", **attrs)
573+
574+
graph = onnx.helper.make_graph(
575+
[pad_node],
576+
"PadFloat",
577+
[input_0],
578+
[output_0],
579+
initializer=initializers,
580+
)
581+
opset_imports = [onnx.helper.make_opsetid("", opset)]
582+
model = onnx.helper.make_model(graph, opset_imports=opset_imports)
583+
model = onnx.shape_inference.infer_shapes(model)
584+
onnx.checker.check_model(model, True)
585+
return model
586+
587+
def test_qdq_pad_qparams(self):
588+
"""
589+
Test that QDQ Pad has equal scale/zero-point for its input and output for certain configurations.
590+
"""
591+
test_configs = [
592+
# Opset 21
593+
("constant", None, 21, onnx.TensorProto.FLOAT),
594+
("constant", None, 21, onnx.TensorProto.FLOAT16),
595+
("constant", 0, 21, onnx.TensorProto.FLOAT),
596+
("constant", 0, 21, onnx.TensorProto.FLOAT16),
597+
("constant", 10.0, 21, onnx.TensorProto.FLOAT),
598+
("constant", 10.0, 21, onnx.TensorProto.FLOAT16),
599+
("reflect", None, 21, onnx.TensorProto.FLOAT),
600+
("reflect", None, 21, onnx.TensorProto.FLOAT16),
601+
("edge", None, 21, onnx.TensorProto.FLOAT),
602+
("edge", None, 21, onnx.TensorProto.FLOAT16),
603+
("wrap", None, 21, onnx.TensorProto.FLOAT),
604+
("wrap", None, 21, onnx.TensorProto.FLOAT16),
605+
# Model with opset 10 will use pad of opset 2, which uses attributes instead of inputs.
606+
# Opset 10 Q/DQ ops don't support float16.
607+
("constant", None, 10, onnx.TensorProto.FLOAT),
608+
("constant", 0, 10, onnx.TensorProto.FLOAT),
609+
("constant", 10.0, 10, onnx.TensorProto.FLOAT),
610+
("reflect", None, 10, onnx.TensorProto.FLOAT),
611+
("edge", None, 10, onnx.TensorProto.FLOAT),
612+
]
613+
614+
for pad_mode, constant_value, opset, float_type in test_configs:
615+
with self.subTest(pad_mode=pad_mode, constant_value=constant_value, opset=opset, float_type=float_type):
616+
label = f"_{pad_mode}_{constant_value}_opset{opset}_{onnx.TensorProto.DataType.Name(float_type)}"
617+
float_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.float.onnx")
618+
qdq_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.qdq.onnx")
619+
620+
float_model = self.build_pad_model(pad_mode, constant_value, opset=opset, float_type=float_type)
621+
onnx.save_model(float_model, float_model_path)
622+
623+
# Create a data reader
624+
np_dtype = onnx.helper.tensor_dtype_to_np_dtype(float_type)
625+
input_data_list = [
626+
{"input_0": np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], dtype=np_dtype)},
627+
{"input_0": np.array([[2.3, 3.4], [4.5, 5.7], [1.0, 1.2]], dtype=np_dtype)},
628+
]
629+
data_reader = TestDataFeeds(input_data_list)
630+
631+
# quantize model to QDQ
632+
quantize_static(
633+
float_model_path,
634+
qdq_model_path,
635+
data_reader,
636+
quant_format=QuantFormat.QDQ,
637+
activation_type=QuantType.QUInt8,
638+
weight_type=QuantType.QInt8,
639+
)
640+
641+
expected_op_counts = {"DequantizeLinear": 2, "QuantizeLinear": 2, "Pad": 1}
642+
if constant_value is not None and opset >= 11:
643+
expected_op_counts["DequantizeLinear"] += 1 # The constant padding value is quantized.
644+
check_op_type_count(self, qdq_model_path, **expected_op_counts)
645+
646+
if pad_mode != "reflect":
647+
# Do not check model correctness for 'reflect' mode because ONNX Runtime implementation does
648+
# not match the ONNX reference implementation. See the following issue:
649+
# https://github.com/microsoft/onnxruntime/issues/20801
650+
data_reader.rewind()
651+
check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next())
652+
653+
qdq_model = onnx.load_model(qdq_model_path)
654+
quant_output_same_as_input = False
655+
656+
if pad_mode in ("reflect", "edge", "wrap"):
657+
quant_output_same_as_input = True
658+
659+
if pad_mode == "constant" and constant_value in (None, 0):
660+
quant_output_same_as_input = True
661+
662+
pad_node = next((node for node in qdq_model.graph.node if node.op_type == "Pad"), None)
663+
self.assertNotEqual(pad_node, None)
664+
self.assertEqual(pad_node.op_type, "Pad")
665+
666+
# Get the parent and child nodes of the Pad and check that they are DQ/Q.
667+
consumers, producers = get_tensor_consumers_and_producers(qdq_model)
668+
input_dq_node = producers.get(pad_node.input[0], None)
669+
self.assertNotEqual(input_dq_node, None)
670+
self.assertEqual(input_dq_node.op_type, "DequantizeLinear")
671+
672+
output_q_node = consumers.get(pad_node.output[0], [None])[0]
673+
self.assertNotEqual(output_q_node, None)
674+
self.assertEqual(output_q_node.op_type, "QuantizeLinear")
675+
676+
# Check that the Pad's input DQ uses the same scale/zp as the Pad's output Q.
677+
if quant_output_same_as_input:
678+
self.assertEqual(input_dq_node.input[1], output_q_node.input[1]) # Same scale
679+
self.assertEqual(input_dq_node.input[2], output_q_node.input[2]) # Same zero-point
680+
else:
681+
self.assertNotEqual(input_dq_node.input[1], output_q_node.input[1])
682+
self.assertNotEqual(input_dq_node.input[2], output_q_node.input[2])
683+
684+
522685
if __name__ == "__main__":
523686
unittest.main()

0 commit comments

Comments
 (0)