Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,17 @@ def __init__(
offload_only_non_sliding: bool = True,
**kwargs,
):
if kwargs:
raise TypeError(f"Unknown arguments passed to StaticCache: {list(kwargs.keys())}")

if not isinstance(offloading, bool):
raise TypeError(
f"`offloading` must be a bool, got {type(offloading)}. "
"Did you accidentally pass `device` as a positional argument?"
)
if not isinstance(offload_only_non_sliding, bool):
raise TypeError(f"`offload_only_non_sliding` must be a bool, got {type(offload_only_non_sliding)}.")

config = config.get_text_config(decoder=True)
layer_types = getattr(config, "layer_types", None)
# If `layer_types` is not explicitly provided, infer if the model is fully sliding
Expand Down
22 changes: 22 additions & 0 deletions tests/utils/test_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,28 @@ def test_static_cache(self):
static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
)

def test_static_cache_type_checks(self):
"""Test that StaticCache validates offloading types and unknown kwargs."""
cache = StaticCache(
config=self.config, max_cache_len=self.max_cache_len, offloading=True, offload_only_non_sliding=False
)
self.assertIsInstance(cache, StaticCache)

# Passing wrong type for offloading should raise TypeError
with self.assertRaises(TypeError) as cm:
StaticCache(config=self.config, max_cache_len=self.max_cache_len, offloading="cuda:0")
self.assertIn("`offloading` must be a bool", str(cm.exception))

# Passing wrong type for offload_only_non_sliding should raise TypeError
with self.assertRaises(TypeError) as cm:
StaticCache(config=self.config, max_cache_len=self.max_cache_len, offload_only_non_sliding=1)
self.assertIn("`offload_only_non_sliding` must be a bool", str(cm.exception))

# Passing unknown kwargs should raise TypeError
with self.assertRaises(TypeError) as cm:
StaticCache(config=self.config, max_cache_len=self.max_cache_len, foo="bar")
self.assertIn("Unknown arguments passed to StaticCache", str(cm.exception))

def test_sliding_window_cache(self):
"""Test fully sliding StaticCache with manually prefilled states and hardcoded assertions.

Expand Down