
Commit b58056b

kunal-vaishnavi authored and jeffkilpatrick committed
Fix number of layers in Whisper export (microsoft#25375)
### Description

This PR fixes the number of hidden layers used during the export of Whisper by always using the number of hidden layers in the decoder.

### Motivation and Context

Most Whisper models contain the same number of hidden layers in the encoder and the decoder. However, Whisper large v3 turbo contains 32 hidden layers in the encoder but only 4 hidden layers in the decoder, so sizing the decoder's inputs and outputs with `num_hidden_layers` (which follows the encoder's layer count) produced the wrong number of layers. This PR also fixes [this issue](microsoft/onnxruntime-genai#1611).
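For context, here is a minimal sketch (not part of this commit) of how the two layer counts diverge for Whisper large v3 turbo; it assumes the `transformers` package pinned in the requirements below, and relies on `WhisperConfig` aliasing `num_hidden_layers` to the encoder's layer count:

```python
# Minimal sketch (not from this commit): inspect the layer counts the export
# previously conflated. Assumes `transformers` is installed.
from transformers import WhisperConfig

config = WhisperConfig.from_pretrained("openai/whisper-large-v3-turbo")
print(config.encoder_layers)     # 32
print(config.decoder_layers)     # 4
# WhisperConfig maps num_hidden_layers to encoder_layers, so the old code
# sized the decoder's KV caches with the encoder's depth (32 instead of 4).
print(config.num_hidden_layers)  # 32
```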
1 parent 6b740de commit b58056b

7 files changed: +13 additions, −14 deletions

onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -410,7 +410,7 @@ def export_onnx_models(
         precision == Precision.FLOAT16,
         model.config.encoder_attention_heads,
         model.config.d_model,
-        model.config.num_hidden_layers,
+        model.config.decoder_layers,
         use_external_data_format,
         use_gpu=use_gpu,
         provider=provider,
```

onnxruntime/python/tools/transformers/models/whisper/requirements.txt

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,5 +1,5 @@
 torch>=2.7.0
-transformers>=4.52.3
+transformers==4.52.3
 openai-whisper==20240927
 ffmpeg-python
 datasets
```

onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py

Lines changed: 3 additions & 4 deletions

```diff
@@ -187,7 +187,7 @@ def input_names(self):
             *list(
                 chain.from_iterable(
                     (f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}")
-                    for i in range(self.config.num_hidden_layers)
+                    for i in range(self.config.decoder_layers)
                 )
             ),
         ]
@@ -205,7 +205,7 @@ def output_names(self):
                         f"present_key_cross_{i}",
                         f"present_value_cross_{i}",
                     )
-                    for i in range(self.config.num_hidden_layers)
+                    for i in range(self.config.decoder_layers)
                 )
             ),
         ]
@@ -214,8 +214,7 @@ def output_names(self):
             "logits",
             *list(
                 chain.from_iterable(
-                    (f"present_key_self_{i}", f"present_value_self_{i}")
-                    for i in range(self.config.num_hidden_layers)
+                    (f"present_key_self_{i}", f"present_value_self_{i}") for i in range(self.config.decoder_layers)
                 )
            ),
        ]
```
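To make the renamed loops concrete, here is a hypothetical standalone rendering (not from the commit) of the past-KV input names the decoder diff above now generates from `config.decoder_layers`:

```python
# Standalone illustration: the per-layer KV-cache input names built above,
# with decoder_layers = 4 as in whisper-large-v3-turbo.
from itertools import chain

decoder_layers = 4
past_kv_names = list(
    chain.from_iterable(
        (f"past_key_self_{i}", f"past_value_self_{i}",
         f"past_key_cross_{i}", f"past_value_cross_{i}")
        for i in range(decoder_layers)
    )
)
print(len(past_kv_names))  # 16: four names per decoder layer
print(past_kv_names[:4])   # ['past_key_self_0', 'past_value_self_0',
                           #  'past_key_cross_0', 'past_value_cross_0']
```

With the old `num_hidden_layers` (32 for large v3 turbo), the same expression would have declared 128 past-KV inputs that the 4-layer decoder never consumes.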

onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -127,7 +127,7 @@ def output_names(self):
             *list(
                 chain.from_iterable(
                     (f"present_key_cross_{i}", f"present_value_cross_{i}")
-                    for i in range(self.config.num_hidden_layers)
+                    for i in range(self.config.decoder_layers)
                 )
             ),
         ]
@@ -143,7 +143,7 @@ def output_names(self):
                         f"present_key_cross_{i}",
                         f"present_value_cross_{i}",
                     )
-                    for i in range(self.config.num_hidden_layers)
+                    for i in range(self.config.decoder_layers)
                 )
             ),
         ]
```

onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -763,7 +763,7 @@ def optimize_onnx(
     is_float16: bool,
     num_attention_heads: int,
     hidden_size: int,
-    num_layers: int,
+    num_decoder_layers: int,
     use_external_data_format: bool = False,
     use_gpu: bool = False,
     provider: str = "cpu",
@@ -801,7 +801,7 @@ def optimize_onnx(
         m = add_cache_indirection_to_mha(m, past_seq_len_name)
 
     if output_qk:
-        m = add_output_qk_to_mha(m, skip_node_idxs=list(range(0, 2 * num_layers, 2)))
+        m = add_output_qk_to_mha(m, skip_node_idxs=list(range(0, 2 * num_decoder_layers, 2)))
 
     m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)
```
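As an aside, a small sketch of the skip list the renamed `num_decoder_layers` parameter feeds to `add_output_qk_to_mha`; the interleaving of self- and cross-attention MHA nodes is our assumption about the decoder graph layout, not something stated in the commit:

```python
# Sketch only: even indices are skipped, which (assuming self-/cross-attention
# MHA nodes alternate per decoder layer) leaves output QK on cross-attention.
num_decoder_layers = 4  # whisper-large-v3-turbo
skip_node_idxs = list(range(0, 2 * num_decoder_layers, 2))
print(skip_node_idxs)  # [0, 2, 4, 6]: one skipped node per decoder layer
```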

onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -94,14 +94,14 @@ def get_sample_past_key_values(
             torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
             torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
         )
-        for _ in range(config.num_hidden_layers)
+        for _ in range(config.decoder_layers)
     ]
     cross_attention_kv_caches = [
         (
             torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype),
             torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype),
         )
-        for _ in range(config.num_hidden_layers)
+        for _ in range(config.decoder_layers)
     ]
     return flatten_past_key_values(self_attention_kv_caches, cross_attention_kv_caches)
 
@@ -187,7 +187,7 @@ def get_sample_QKs(  # noqa: N802
         torch.rand(
             batch_size, num_heads, sequence_length, config.max_source_positions, device=device, dtype=torch_dtype
         )
-        for _ in range(config.num_hidden_layers)
+        for _ in range(config.decoder_layers)
     ]
     return QKs
```
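Here is a self-contained sketch (sample dimensions are illustrative, not from the commit) of the self-attention KV caches that `get_sample_past_key_values` now builds, one `(key, value)` pair per decoder layer:

```python
# Illustrative shapes only; the real helper derives them from the model config.
import torch

batch_size, num_heads, head_size = 2, 20, 64  # 20 heads x 64 head_size ~ d_model 1280
past_seq_len, decoder_layers = 8, 4           # decoder_layers = 4 for large v3 turbo

self_attention_kv_caches = [
    (
        torch.rand(batch_size, num_heads, past_seq_len, head_size),
        torch.rand(batch_size, num_heads, past_seq_len, head_size),
    )
    for _ in range(decoder_layers)  # previously num_hidden_layers, i.e. 32
]
print(len(self_attention_kv_caches))         # 4
print(self_attention_kv_caches[0][0].shape)  # torch.Size([2, 20, 8, 64])
```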

onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -156,7 +156,7 @@ def input_names(self):
             "alignment_heads",
             "sot_sequence_length",
             "segment_length",
-            *[f"cross_qk_{i}" for i in range(self.config.num_hidden_layers)],
+            *[f"cross_qk_{i}" for i in range(self.config.decoder_layers)],
         ]
         return input_names
```
