Skip to content
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ def _create_tokenized_object(
sampling_kwargs = obj.sampling_params
sampling_params = SamplingParams(**sampling_kwargs)
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
sampling_params.verify(self.model_config.vocab_size)

# Build return object
if isinstance(obj, GenerateReqInput):
Expand Down
9 changes: 8 additions & 1 deletion python/sglang/srt/sampling/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
if self.top_k == -1:
self.top_k = TOP_K_ALL # whole vocabulary

def verify(self):
def verify(self, vocab_size):
if self.temperature < 0.0:
raise ValueError(
f"temperature must be non-negative, got {self.temperature}."
Expand Down Expand Up @@ -131,6 +131,13 @@ def verify(self):
f"min_new_tokens must be in [0, max_new_tokens({self.max_new_tokens})], got "
f"{self.min_new_tokens}."
)
if self.logit_bias is not None:
for token_id in self.logit_bias:
if not 0 <= int(token_id) < vocab_size:
raise ValueError(
f"logit_bias must have keys in [0, {vocab_size - 1}], got "
f"{token_id}."
)
grammars = [
self.json_schema,
self.regex,
Expand Down
Loading