diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index e6f2645a766e..3165a0616c4b 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1069,6 +1069,17 @@ def __init__(
         offload_only_non_sliding: bool = True,
         **kwargs,
     ):
+        if kwargs:
+            raise TypeError(f"Unknown arguments passed to StaticCache: {list(kwargs.keys())}")
+
+        if not isinstance(offloading, bool):
+            raise TypeError(
+                f"`offloading` must be a bool, got {type(offloading)}. "
+                "Did you accidentally pass `device` as a positional argument?"
+            )
+        if not isinstance(offload_only_non_sliding, bool):
+            raise TypeError(f"`offload_only_non_sliding` must be a bool, got {type(offload_only_non_sliding)}.")
+
         config = config.get_text_config(decoder=True)
         layer_types = getattr(config, "layer_types", None)
         # If `layer_types` is not explicitly provided, infer if the model is fully sliding
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index b3b03c49f5e3..5734ce8bc7e1 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -937,6 +937,28 @@ def test_static_cache(self):
             static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
         )
 
+    def test_static_cache_type_checks(self):
+        """Test that StaticCache validates offloading types and unknown kwargs."""
+        cache = StaticCache(
+            config=self.config, max_cache_len=self.max_cache_len, offloading=True, offload_only_non_sliding=False
+        )
+        self.assertIsInstance(cache, StaticCache)
+
+        # Passing wrong type for offloading should raise TypeError
+        with self.assertRaises(TypeError) as cm:
+            StaticCache(config=self.config, max_cache_len=self.max_cache_len, offloading="cuda:0")
+        self.assertIn("`offloading` must be a bool", str(cm.exception))
+
+        # Passing wrong type for offload_only_non_sliding should raise TypeError
+        with self.assertRaises(TypeError) as cm:
+            StaticCache(config=self.config, max_cache_len=self.max_cache_len, offload_only_non_sliding=1)
+        self.assertIn("`offload_only_non_sliding` must be a bool", str(cm.exception))
+
+        # Passing unknown kwargs should raise TypeError
+        with self.assertRaises(TypeError) as cm:
+            StaticCache(config=self.config, max_cache_len=self.max_cache_len, foo="bar")
+        self.assertIn("Unknown arguments passed to StaticCache", str(cm.exception))
+
     def test_sliding_window_cache(self):
         """Test fully sliding StaticCache with manually prefilled states and hardcoded assertions.
 