
@kylesayrs (Contributor) commented Jun 23, 2025

Purpose

Without this change, attempting to CPU-offload the encoder raises a device error:

RuntimeError: Tensor on device meta is not on the expected device cuda:0!
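For context, a minimal illustration (shapes hypothetical) of the offloaded state: accelerate's hook-based offloading leaves a data-less placeholder on the meta device in place of the real parameter, so any code path that reads the attribute directly never sees real weights.

import torch

# Hypothetical stand-in for an offloaded parameter: a meta tensor carries
# shape and dtype metadata but has no storage. The real weight only appears
# when a module's pre-forward hook onloads it.
weight = torch.empty(1500, 1280, device="meta")
print(weight.device)  # meta
print(weight.shape)   # torch.Size([1500, 1280])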

Changes

  • Instead of reading the embed_positions.weight attribute directly, leverage the HF hooks attached to the embed_positions module to onload the weight properly (see the sketch below).
    • This incurs a small, once-per-request runtime cost, since F.embedding must be called with an identity matrix rather than grabbing the weight value directly.
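A minimal sketch of the pattern (illustrative names and sizes, not the actual diff): going through the module's forward instead of its weight attribute gives the attached accelerate hook a chance to onload the weight to the execution device first.

import torch
import torch.nn as nn

# Illustrative stand-in for the encoder's embed_positions module
embed_positions = nn.Embedding(1500, 1280)

# Before: direct attribute access bypasses any hooks attached to the module,
# so under offloading it returns the meta-device placeholder.
# embed_pos = embed_positions.weight

# After: calling the module runs its (hook-wrapped) forward; looking up every
# index recovers the full weight table while letting the hook fire first.
position_ids = torch.arange(embed_positions.num_embeddings)
embed_pos = embed_positions(position_ids)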

Testing

Use the following test script to verify that generation works with this device map:

device_map={
    "model.encoder": "cpu",
    "model.decoder": 0,
    "proj_out": 0,
},
test_whisper_offload.py
import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor


def load_sample(processor):
    # Fetch one sample from the People's Speech test split and preprocess it for Whisper
    ds = load_dataset(
        "MLCommons/peoples_speech",
        "test", split="test[:1]",
        trust_remote_code=True,
    )

    sample = next(iter(ds))
    sample = processor(
        audio=sample["audio"]["array"],
        sampling_rate=sample["audio"]["sampling_rate"],
        text=(" " + sample["text"].capitalize()),
        add_special_tokens=True,
        return_tensors="pt",
    )

    sample["input_features"] = sample["input_features"].to(dtype=torch.bfloat16)
    sample["decoder_input_ids"] = torch.tensor([processor.tokenizer.prefix_tokens])
    del sample["labels"]

    return sample


if __name__ == "__main__":
    model_id = "openai/whisper-large-v3"
    model = WhisperForConditionalGeneration.from_pretrained(
        model_id,
        device_map={
            "model.encoder": "cpu",
            "model.decoder": 0,
            "proj_out": 0,
        },
        torch_dtype=torch.bfloat16
    )
    processor = WhisperProcessor.from_pretrained(model_id)

    # With the encoder offloaded, its position-embedding weight is a meta-device placeholder
    assert model.model.encoder.embed_positions.weight.device == torch.device("meta")
    sample = load_sample(processor)
    output = model.generate(**sample, language="en")
    print(processor.batch_decode(output, skip_special_tokens=True))
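On a single-GPU machine, running this script should print the decoded transcript. The assertion beforehand confirms that the encoder's position-embedding weight really is the meta-device placeholder, i.e. that the encoder is offloaded when generate runs.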


@zucchini-nlp (Member) left a comment

Thanks, I like that we removed direct access to weight.data. Can you also un-skip offload tests in whisper and make sure they are green?

For example:

@unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
def test_cpu_offload(self):
    pass

@unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
def test_disk_offload_bin(self):
    pass

@unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
def test_disk_offload_safetensors(self):
    pass

@kylesayrs (Contributor, Author) commented Jun 24, 2025

@zucchini-nlp @SunMarc Tests unskipped and passing!

@zucchini-nlp (Member)

run-slow: whisper

This comment contains run-slow, running the specified jobs:

models: ['models/whisper']
quantizations: [] ...

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kylesayrs (Contributor, Author)

@zucchini-nlp Does this test failure indicate something to fix, or is this test noisy?

First differing element 3:
" Fol[1422 chars]ugitives bug out bindle of news that is my segment. Meanwhile."
" Fol[1422 chars]ugitives bug out bindle of news that is my segment. Meanwhile!"

@vasqu (Contributor) left a comment

@kylesayrs Can you rebase/merge? The failing tests are expected, no worries :D

cc @gante @eustlb for viz

@kylesayrs (Contributor, Author)

@vasqu Merged, glad to hear it :)

@vasqu enabled auto-merge (squash) June 26, 2025 15:43
@vasqu merged commit 0a8081b into huggingface:main Jun 26, 2025
20 checks passed
@vasqu (Contributor) commented Jun 26, 2025

Thanks @kylesayrs 🤗

zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025
* fix cpu offloading for whisper

Signed-off-by: Kyle Sayers <[email protected]>

* unskip offloading tests

Signed-off-by: Kyle Sayers <[email protected]>

* revert small change

Signed-off-by: Kyle Sayers <[email protected]>

* remove tests

Signed-off-by: Kyle Sayers <[email protected]>

---------

Signed-off-by: Kyle Sayers <[email protected]>