Problems trying to convert & run whisper model #1611

@RyanMetcalfeInt8

For Whisper, there seem to be both conversion issues and GenAI runtime errors.

onnxruntime:

commit a3c3e2f038216d03a184e76ac32e452612a45e33 (HEAD -> main, origin/main, origin/HEAD)
Author: Yulong Wang <[email protected]>
Date:   Mon Jul 7 02:02:43 2025 -0700

onnxruntime-genai:

commit 2f2ad90b7e5c132d2a6aa986031305ac7898ef48 (HEAD -> main, origin/main, origin/HEAD)
Author: Kinfey <[email protected]>
Date:   Tue Jul 8 03:01:22 2025 +0800

Conversion Issues

I tried following instructions from here: https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/whisper#exporting-whisper-for-onnx-runtime-genai

For this command:
python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision fp32 --provider cpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk

I received this runtime error (Python 3.10, Ubuntu 24.04):

Arguments:Namespace(model_name_or_path='openai/whisper-large-v3-turbo', model_impl='hf', cache_dir='./cache_models', output='whisper-turbo', optimize_onnx=False, use_gpu=False, precision=<Precision.FLOAT32: 'fp32'>, use_int64_inputs=False, provider='cpu', verbose=False, use_external_data_format=True, overwrite=False, separate_encoder_and_decoder_init=False, no_beam_search_op=True, use_decoder_masked_mha=False, use_vocab_mask=False, use_prefix_vocab_mask=False, use_forced_decoder_ids=False, use_logits_processor=False, collect_cross_qk=True, extra_decoding_ids=False, use_temperature=False, no_repeat_ngram_size=0, output_sequence_scores=False, output_scores=False, output_cross_qk=True, cross_qk_onnx_model=None, output_no_speech_probs=False, quantize_embedding_layer=False, quantize_per_channel=False, quantize_reduce_range=False, use_specific_logits_processor=False)
PyTorch Version:2.7.1+cu126
Transformers Version:4.53.1
OnnxRuntime Version:1.22.0
Skip exporting: existing ONNX model whisper-turbo/whisper-large-v3-turbo_decoder.onnx
Exporting ONNX model to whisper-turbo/whisper-large-v3-turbo_encoder.onnx
/home/rdmetca/Whisper/onnxruntime/onnxruntime/python/tools/transformers/models/whisper/my_env_310/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py:676: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if input_features.shape[-1] != expected_seq_length:
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/rdmetca/Whisper/onnxruntime/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py", line 573, in <module>
    main()
  File "/home/rdmetca/Whisper/onnxruntime/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py", line 479, in main
    output_paths = export_onnx_models(
  File "/home/rdmetca/Whisper/onnxruntime/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py", line 381, in export_onnx_models
    WhisperHelper.export_onnx(
  File "/home/rdmetca/Whisper/onnxruntime/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py", line 728, in export_onnx
    model.export_onnx(
  File "/home/rdmetca/Whisper/onnxruntime/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py", line 288, in export_onnx
    torch.onnx.export(
  File "/home/rdmetca/Whisper/onnxruntime/onnxruntime/python/tools/transformers/models/whisper/my_env_310/lib/python3.10/site-packages/torch/onnx/__init__.py", line 396, in export
    export(
  File "/home/rdmetca/Whisper/onnxruntime/onnxruntime/python/tools/transformers/models/whisper/my_env_310/lib/python3.10/site-packages/torch/onnx/utils.py", line 529, in export
    _export(
  File "/home/rdmetca/Whisper/onnxruntime/onnxruntime/python/tools/transformers/models/whisper/my_env_310/lib/python3.10/site-packages/torch/onnx/utils.py", line 1467, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/rdmetca/Whisper/onnxruntime/onnxruntime/python/tools/transformers/models/whisper/my_env_310/lib/python3.10/site-packages/torch/onnx/utils.py", line 1143, in _model_to_graph
    _set_input_and_output_names(graph, input_names, output_names)
  File "/home/rdmetca/Whisper/onnxruntime/onnxruntime/python/tools/transformers/models/whisper/my_env_310/lib/python3.10/site-packages/torch/onnx/utils.py", line 1596, in _set_input_and_output_names
    set_names(list(graph.outputs()), output_names, "output")
  File "/home/rdmetca/Whisper/onnxruntime/onnxruntime/python/tools/transformers/models/whisper/my_env_310/lib/python3.10/site-packages/torch/onnx/utils.py", line 1573, in set_names
    raise RuntimeError(
RuntimeError: number of output names provided (65) exceeded number of outputs (9)
========> Handling decoder model......
========> Handling encoder model......
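One possible reading of the two numbers in the error, offered as a hypothesis rather than a confirmed diagnosis: 65 = 1 + 2 × 32 and 9 = 1 + 2 × 4. openai/whisper-large-v3-turbo has 32 encoder layers but only 4 decoder layers, whereas openai/whisper-base has 6 of each. If the export builds its output-name list from one layer count while the traced graph produces outputs from the other, the counts would disagree only for turbo. The sketch below just checks that arithmetic. The helper names and the "one extra tensor plus two per layer" layout are assumptions for illustration, not taken from convert_to_onnx.py.

```python
# Hypothetical helper: assumes the exported graph emits one extra tensor
# plus two tensors (key and value) per layer. This layout is an assumption.
def count_outputs(num_layers, per_layer=2, extra=1):
    return extra + per_layer * num_layers


# Layer counts from the public model configs, hard-coded here for illustration.
LAYERS = {
    "openai/whisper-large-v3-turbo": {"encoder": 32, "decoder": 4},
    "openai/whisper-base": {"encoder": 6, "decoder": 6},
}


def names_vs_outputs(model_id):
    """Return (names sized by encoder layers, outputs sized by decoder layers)."""
    layers = LAYERS[model_id]
    names = count_outputs(layers["encoder"])
    outputs = count_outputs(layers["decoder"])
    return names, outputs


for model_id in LAYERS:
    names, outputs = names_vs_outputs(model_id)
    status = "mismatch" if names != outputs else "ok"
    print(f"{model_id}: {names} names vs {outputs} outputs ({status})")
```

If this reading is right, it would also be consistent with whisper-base exporting cleanly, since its encoder and decoder depths are equal.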

I also received the same error on Windows, using Python 3.10.

When I replaced openai/whisper-large-v3-turbo with openai/whisper-base, the conversion seemed to work on Ubuntu. On Windows, however, it seemed to fail during a ninja/build step that runs after conversion.

Whisper GenAI Example Issues

Since I was able to produce a whisper-base model, I went on to try running the sample. It fails with a Reshape error:

(ortgenai_python_env) C:\Users\LocalAdmin\ort_build\BuildArtifacts-20250707_111001\onnxruntime-genai\examples\python>python whisper.py -m C:\Users\LocalAdmin\ort_build\WhisperModels\whisper-base
Loading model...
Audio Paths (comma separated): C:\Users\LocalAdmin\ort_build\BuildArtifacts-20250707_111001\onnxruntime-genai\test\test_models\audios\jfk.flac
Loading audio...
Processing audio...
2025-07-08 07:07:27.7721011 [E:onnxruntime:onnxruntime-genai, sequential_executor.cc:572 onnxruntime::ExecuteKernel] Non-zero status code returned while running Reshape node. Name:'/decoder/Reshape_1' Status Message: C:\Users\LocalAdmin\ort_build\BuildArtifacts-20250707_111001\onnxruntime\onnxruntime\core\providers\cpu\tensor\reshape_helper.h:47 onnxruntime::ReshapeHelper::ReshapeHelper input_shape_size == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{4}, requested shape:{1,1}

Traceback (most recent call last):
  File "C:\Users\LocalAdmin\ort_build\BuildArtifacts-20250707_111001\onnxruntime-genai\examples\python\whisper.py", line 122, in <module>
    run(args)
  File "C:\Users\LocalAdmin\ort_build\BuildArtifacts-20250707_111001\onnxruntime-genai\examples\python\whisper.py", line 66, in run
    generator.set_inputs(inputs)
RuntimeError: Non-zero status code returned while running Reshape node. Name:'/decoder/Reshape_1' Status Message: C:\Users\LocalAdmin\ort_build\BuildArtifacts-20250707_111001\onnxruntime\onnxruntime\core\providers\cpu\tensor\reshape_helper.h:47 onnxruntime::ReshapeHelper::ReshapeHelper input_shape_size == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{4}, requested shape:{1,1}
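For context on what the message is checking: the `input_shape_size == size` assertion fails when the input tensor and the requested shape do not hold the same number of elements. A minimal sketch of that check in plain Python follows. The function name is hypothetical, and it deliberately ignores the special `0` and `-1` entries that the ONNX Reshape operator also accepts.

```python
from math import prod


def can_reshape(input_shape, requested_shape):
    """Return True when both shapes hold the same number of elements.

    Simplified illustration only: does not handle the 0 / -1 placeholders
    that ONNX Reshape supports.
    """
    return prod(input_shape) == prod(requested_shape)


# Shapes from the error message: 4 elements cannot fill a 1x1 tensor.
print(can_reshape((4,), (1, 1)))
# A shape with a matching element count would pass the same check.
print(can_reshape((4,), (1, 4)))
```

So the node named '/decoder/Reshape_1' is receiving a 4-element tensor where the exported graph expects exactly one element. Which input supplies those 4 elements is not visible from the log alone.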

Please advise.
