Skip to content

ONNX Runtime static quantization of YOLO11 results in zero mAP metrics #528

@yuguerten

Description

@yuguerten

I am trying to quantize the YOLO11 model with onnxruntime, and I have conducted a comparative study of the mAP across all the models I possess: ".pt, onnx, onnx+dynamic, and onnx+static". What I found strange is that in onnx+static quantization, all the YOLO-based metrics are zero (Image). Here is my code:

pt_model = YOLO("best.pt")
dataset = "data.yaml"
PATH = "workplace/models"
workdir = os.path.dirname(PATH)
path_onnx_model = os.path.join(workdir, "best.onnx")
path_preprocess_model = os.path.join(workdir, "prep.onnx")
path_static_quantized_model = os.path.join(workdir, "static.onnx")
path_dynamic_quantized_model = os.path.join(workdir, "dynamic.onnx")
model = YOLO(pt_model, task='detect', verbose=False)
baseline_metrics = model.val(
    data=dataset,
    imgsz=768,
    iou=0.3,
    conf=0.5,
    batch=16,
    device=device
)
pt_model.export(
    format="onnx",
    imgsz=768,
    dynamic=True,
    nms =True,
    conf=0.5,
    iou=0.3,
    batch=16, # expected faster inference
    data=dataset,
    device=device,
)
model = YOLO(path_onnx_model, task='detect', verbose=False)
onnx_metrics = model.val(
    data=dataset,
    imgsz=768,
    iou=0.3,
    conf=0.5,
    project=workdir,
    batch=16,
    device=device
)
shape_inference.quant_pre_process(
   path_onnx_model,
   path_preprocess_model,
   skip_optimization=False,
   skip_onnx_shape=False,
   skip_symbolic_shape=True,
   auto_merge=False,
   int_max=2**31 - 1,
   guess_output_rank=False,
   verbose=0,
   save_as_external_data=False,
   all_tensors_to_one_file=False,
   external_data_location="./",
   external_data_size_threshold=1024)
class ImageFolderCalibrationDataReader(CalibrationDataReader):
    def __init__(self, folder_path, input_name, input_height, input_width):
        super().__init__()
        self.folder_path = folder_path
        self.input_name = input_name
        self.input_height = input_height
        self.input_width = input_width
        self.calibration_images = [f for f in os.listdir(folder_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        self.current_item = 0

    def get_next(self) -> dict:
        if self.current_item >= len(self.calibration_images):
            return None

        image_path = os.path.join(self.folder_path, self.calibration_images[self.current_item])
        image = cv2.imread(image_path)
        if image is None:
            self.current_item += 1
            return self.get_next()
            
        image = cv2.resize(image, (self.input_width, self.input_height))
        image = image.astype('float32') / 255.0  # Normalize to [0, 1]
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image = np.expand_dims(image, axis=0)  # Add batch dimension

        self.current_item += 1
        return {self.input_name: image}

    def __len__(self) -> int:
        return len(self.calibration_images)
calibration_data_path = "images"
input_name = 'images'
input_shape = (768, 768, 3)
calibration_dataset = ImageFolderCalibrationDataReader(
    calibration_data_path, 
    input_name, 
    input_shape[0], 
    input_shape[1]
)
quantize_static(
            path_preprocess_model,
            path_static_quantized_model,
            calibration_dataset,
            quant_format=QuantFormat.QDQ,
            weight_type=QuantType.QInt8,
            activation_type=QuantType.QInt8,
            per_channel=False,
            reduce_range=False,
        )
model = YOLO(path_static_quantized_model, task='detect', verbose=False)
static_metrics = model.val(
    data=dataset,
    imgsz=768,
    iou=0.3,
    conf=0.5,
    project=workdir,
    batch=16
)
import onnx
model = onnx.load(path_preprocess_model)
nodes_to_exclude = [
    node.name for node in model.graph.node if node.op_type == "Conv"
]
quantize_dynamic(
    model_input=path_preprocess_model,
    model_output=path_dynamic_quantized_model,
    weight_type=QuantType.QInt8,
    nodes_to_exclude = nodes_to_exclude
)
model = YOLO(path_dynamic_quantized_model, task='detect', verbose=False)
dynamic_metrics = model.val(
    data=dataset,
    imgsz=768,
    iou=0.3,
    conf=0.5,
    project=workdir,
    batch=16
)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions