Describe the issue
Versions:
- `optimum==1.17.1`
- `onnx==1.15.0`
Attempting to quantize a new APISR model produces the following warnings and then fails with an `AttributeError`:
```
WARNING:root:Please consider to run pre-processing before quantization. Refer to example: https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification/cpu/ReadMe.md
WARNING:root:Inference failed or unsupported type to quantize for tensor '/Slice_output_0', type is tensor_type {
  elem_type: 7
  shape {
    dim {
      dim_param: "unk__3"
    }
    dim {
      dim_value: 2
    }
  }
}
.
```
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-29-cd2ec8e65af8> in <cell line: 6>()
      4 )
      5
----> 6 quantize_dynamic(
      7     model_input='model.onnx',
      8     model_output='output',

4 frames
/usr/local/lib/python3.10/dist-packages/onnxruntime/quantization/quantize.py in quantize_dynamic(model_input, model_output, op_types_to_quantize, per_channel, reduce_range, weight_type, nodes_to_quantize, nodes_to_exclude, use_external_data_format, extra_options)
    640     )
    641
--> 642     quantizer.quantize_model()
    643     quantizer.model.save_model_to_file(model_output, use_external_data_format)
    644

/usr/local/lib/python3.10/dist-packages/onnxruntime/quantization/onnx_quantizer.py in quantize_model(self)
    401             number_of_existing_new_nodes = len(self.new_nodes)
    402             op_quantizer = CreateOpQuantizer(self, node)
--> 403             op_quantizer.quantize()
    404             for i in range(number_of_existing_new_nodes, len(self.new_nodes)):
    405                 for output_name in self.new_nodes[i].output:

/usr/local/lib/python3.10/dist-packages/onnxruntime/quantization/operators/base_operator.py in quantize(self)
     19         """
     20         for _, node_input in enumerate(self.node.input):
---> 21             dequantize_node = self.quantizer._dequantize_value(node_input)
     22             if dequantize_node is not None:
     23                 self.quantizer.new_nodes.append(dequantize_node)

/usr/local/lib/python3.10/dist-packages/onnxruntime/quantization/onnx_quantizer.py in _dequantize_value(self, value_name)
   1362         ):
   1363             # axis is not specified so scale_init must be a scalar.
-> 1364             assert onnx.numpy_helper.to_array(scale_init).size == 1
   1365
   1366             dqlinear_name = value_name + "_DequantizeLinear"

/usr/local/lib/python3.10/dist-packages/onnx/numpy_helper.py in to_array(tensor, base_dir)
    182         arr: the converted array.
    183     """
--> 184     if tensor.HasField("segment"):
    185         raise ValueError("Currently not supporting loading segments.")
    186     if tensor.data_type == TensorProto.UNDEFINED:

AttributeError: 'NoneType' object has no attribute 'HasField'
```
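Reading the traceback, it looks like the quantizer marks `/Slice_output_0` (an int64 tensor, per the warning above) as quantized but never produces a scale initializer for it, so `scale_init` is `None` by the time `_dequantize_value` runs. As a speculative, untested workaround sketch, one could exclude the nodes that consume that tensor via the existing `nodes_to_exclude` parameter:

```python
# Untested workaround sketch: skip quantizing the nodes that consume the
# int64 tensor '/Slice_output_0', so the quantizer never tries to build a
# DequantizeLinear for a value that has no scale initializer.
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

model = onnx.load('model.onnx')
bad_consumers = [
    node.name
    for node in model.graph.node
    if node.name and '/Slice_output_0' in node.input
]

quantize_dynamic(
    model_input='model.onnx',
    model_output='model_quant.onnx',  # hypothetical output path
    weight_type=QuantType.QUInt8,
    nodes_to_exclude=bad_consumers,   # speculative mitigation, not verified
)
```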
NOTE: Downgrading to `onnxruntime<1.16.0` and `onnx==1.13.1` seems to work.
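For reference, the corresponding pin (assuming a pip-based environment):

```
pip install "onnxruntime<1.16.0" "onnx==1.13.1"
```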
To reproduce
1. Download this model:
```
wget https://huggingface.co/Xenova/4x_APISR_GRL_GAN_generator-onnx/resolve/main/onnx/model.onnx
```
2. Attempt to quantize the model:
```python
from onnxruntime.quantization import (
    quantize_dynamic,
    QuantType,
)

quantize_dynamic(
    model_input='model.onnx',
    model_output='output',
    weight_type=QuantType.QUInt8,
    extra_options=dict(
        EnableSubgraph=True
    ),
    per_channel=True,
    reduce_range=True,
)
```
Note: preprocessing doesn't fix this either:
```
python -m onnxruntime.quantization.preprocess --input model.onnx --output model-infer.onnx
```
produces
```
unsupported broadcast between floor(floor(Pad_50_o0__d2/8)*floor(Pad_50_o0__d3/8)*floor(Pad_50_o0__d0*floor(Pad_50_o0__d2/8 + Min(4, Pad_50_o0__d2)/8 - 1/2)*floor(Pad_50_o0__d3/8 + Min(4, Pad_50_o0__d3)/8 - 1/2)/(floor(Pad_50_o0__d2/8)*floor(Pad_50_o0__d3/8)))) Pad_50_o0__d0*floor(Pad_50_o0__d2/8 + Min(4, Pad_50_o0__d2)/8 - 1/2)*floor(Pad_50_o0__d3/8 + Min(4, Pad_50_o0__d3)/8 - 1/2)
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/quantization/preprocess.py", line 127, in <module>
    quant_pre_process(
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/quantization/shape_inference.py", line 71, in quant_pre_process
    model = SymbolicShapeInference.infer_shapes(
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/tools/symbolic_shape_infer.py", line 2859, in infer_shapes
    raise Exception("Incomplete symbolic shape inference")
Exception: Incomplete symbolic shape inference
```
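In case it helps triage: since it is only the symbolic shape inference pass that fails, a sketch of running the same pre-processing while skipping that pass (untested on this model, and assuming `skip_symbolic_shape` behaves as its name suggests):

```python
# Sketch (untested): run the quantization pre-processing pipeline but skip
# the symbolic shape inference step that raises "Incomplete symbolic shape
# inference" on this model. Optimization and ONNX shape inference still run.
from onnxruntime.quantization.shape_inference import quant_pre_process

quant_pre_process(
    'model.onnx',        # input model path
    'model-infer.onnx',  # output model path
    skip_symbolic_shape=True,
)
```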
Urgency
This prevents me from upgrading to onnxruntime v1.17.1, which I need to convert transformers.js models.
Platform
Linux
OS Version
Ubuntu 22.04.3 LTS
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.17.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Other / Unknown
Execution Provider Library Version
n/a