diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 81f36faa616..38a8fa53af7 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -604,7 +604,7 @@ def _create_tokenized_object(
         sampling_kwargs = obj.sampling_params
         sampling_params = SamplingParams(**sampling_kwargs)
         sampling_params.normalize(self.tokenizer)
-        sampling_params.verify()
+        sampling_params.verify(self.model_config.vocab_size)
 
         # Build return object
         if isinstance(obj, GenerateReqInput):
diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py
index e0a88107fff..b7d1a6d6e2c 100644
--- a/python/sglang/srt/sampling/sampling_params.py
+++ b/python/sglang/srt/sampling/sampling_params.py
@@ -89,7 +89,7 @@ def __init__(
         if self.top_k == -1:
             self.top_k = TOP_K_ALL  # whole vocabulary
 
-    def verify(self):
+    def verify(self, vocab_size):
         if self.temperature < 0.0:
             raise ValueError(
                 f"temperature must be non-negative, got {self.temperature}."
@@ -131,6 +131,13 @@ def verify(self):
                 f"min_new_tokens must be in [0, max_new_tokens({self.max_new_tokens})], got "
                 f"{self.min_new_tokens}."
             )
+        if self.logit_bias is not None:
+            for token_id in self.logit_bias:
+                if not 0 <= int(token_id) < vocab_size:
+                    raise ValueError(
+                        f"logit_bias must have keys in [0, {vocab_size - 1}], got "
+                        f"{token_id}."
+                    )
         grammars = [
             self.json_schema,
             self.regex,
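
For context, a minimal standalone sketch of the bounds check this patch adds. The helper name `verify_logit_bias` and the vocab size of 32000 are illustrative, not part of the patch; the validation body mirrors the new code in `SamplingParams.verify`:

```python
def verify_logit_bias(logit_bias, vocab_size):
    """Reject logit_bias keys outside [0, vocab_size)."""
    if logit_bias is not None:
        for token_id in logit_bias:
            # int() cast: keys may arrive as strings when the request
            # comes in as JSON (JSON object keys are always strings).
            if not 0 <= int(token_id) < vocab_size:
                raise ValueError(
                    f"logit_bias must have keys in [0, {vocab_size - 1}], got "
                    f"{token_id}."
                )

verify_logit_bias({"50256": 2.0}, vocab_size=32000)      # within range, passes
try:
    verify_logit_bias({"99999": -5.0}, vocab_size=32000)  # out of range
except ValueError as e:
    print(e)  # logit_bias must have keys in [0, 31999], got 99999.
```

With this change, an out-of-range key fails fast in the tokenizer manager with a clear error, instead of propagating an invalid token id into sampling.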