Closed
Description
The offloading check causes problems with torch.compile. The offloading parameter is type-hinted as bool, but in some cases its value is actually a CUDA device.
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
For example, with Qwen3-omni, constructing the cache as shown here leaves offloading set to a CUDA device. That value trips the offloading check, which crashes torch.compile:
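from transformers import StaticCache
# model, max_model_len, device, and compute_dtype are defined earlier in the calling script
past_key_values = StaticCache(model.thinker.config, max_model_len, device, compute_dtype)
print(past_key_values.offloading)
# prints `cuda:0`, but should be False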
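One possible explanation, which this report does not confirm, is that the device is passed positionally and lands in a slot meant for the offloading flag. The sketch below illustrates that mechanism with a hypothetical stand-in class, `FakeStaticCache`; its parameter order is an assumption made for illustration and is not taken from the transformers source. Plain strings stand in for the device and dtype so the example runs without torch.

```python
class FakeStaticCache:
    """Hypothetical stand-in. The parameter order is an assumption, not the library's API."""

    def __init__(self, config, max_cache_len, offloading: bool = False,
                 offload_only_non_sliding: bool = True, **kwargs):
        self.config = config
        self.max_cache_len = max_cache_len
        # Type hints are not enforced at runtime, so any object is stored as-is.
        self.offloading = offloading
        self.offload_only_non_sliding = offload_only_non_sliding
        self.extra = kwargs


config = {"hidden_size": 8}  # placeholder config for the sketch

# Positional call in the style of the reproduction: the device string
# fills the third slot, which is `offloading` in this stand-in.
positional = FakeStaticCache(config, 128, "cuda:0", "bfloat16")

# Keyword call: the extra values go to **kwargs and `offloading` keeps its default.
keyword = FakeStaticCache(config, 128, device="cuda:0", dtype="bfloat16")

print(repr(positional.offloading))
print(repr(keyword.offloading))
```

If the real constructor behaves like this stand-in, passing the extra arguments by keyword would keep offloading at its default, though that depends on the installed transformers version.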
Expected behavior
StaticCache should work with torch.compile without crashing, and offloading should be False in the reproduction.
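Until the cause is resolved, a mis-typed flag can be caught before compilation with a strict type check. The helper below, `require_bool`, is a hypothetical name introduced for this sketch and is not part of transformers. It rejects anything that is not a real bool instead of relying on truthiness, and a plain string stands in for the CUDA device.

```python
def require_bool(value, name):
    """Return value unchanged if it is a real bool, otherwise raise TypeError."""
    if not isinstance(value, bool):
        raise TypeError(
            f"{name} must be a bool, got {type(value).__name__}: {value!r}"
        )
    return value


# A genuine bool passes through untouched.
checked = require_bool(False, "offloading")

# A device-like value is rejected rather than silently treated as truthy.
try:
    require_bool("cuda:0", "offloading")
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

Calling such a check on `past_key_values.offloading` right after building the cache would fail fast with a readable message rather than surfacing later inside torch.compile.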