
Commit f8ca236

zhyncs, cicirori, and gongwei-130 authored
fix: kimi k2 xgrammar crash (#8367)
Co-authored-by: cicirori <[email protected]>
Co-authored-by: gongwei-130 <[email protected]>
1 parent d8ee156 commit f8ca236

File tree

2 files changed: +15 -2 lines changed

python/sglang/srt/managers/schedule_batch.py

Lines changed: 10 additions & 0 deletions
@@ -431,6 +431,7 @@ def __init__(
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
        data_parallel_rank: Optional[int] = None,
+        vocab_size: Optional[int] = None,
    ):
        # Input and output info
        self.rid = rid
@@ -480,6 +481,7 @@ def __init__(
        self.to_abort_message: str = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids
+        self.vocab_size = vocab_size

        # For incremental decoding
        # ----- | --------- read_ids -------|
@@ -713,6 +715,14 @@ def check_finished(self):
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

+        if last_token_id > self.vocab_size or last_token_id < 0:
+            if self.sampling_params.stop_token_ids:
+                self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
+            if self.eos_token_ids:
+                self.output_ids[-1] = next(iter(self.eos_token_ids))
+            self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
+            return
+
        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
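The schedule_batch.py change gives each Req the model's vocab_size and makes check_finished treat a last token id outside [0, vocab_size] as corrupted output: the token is overwritten with a stop token (or an EOS token, if one is configured) and the request is finished with a "NaN happened" reason instead of crashing downstream. Below is a minimal standalone sketch of that substitution; the helper name and the example numbers are invented for illustration, only the logic mirrors the hunk above.

from typing import Iterable, Optional

def substitute_invalid_token(
    last_token_id: int,
    vocab_size: int,
    stop_token_ids: Optional[Iterable[int]],
    eos_token_ids: Optional[Iterable[int]],
) -> Optional[int]:
    """Return a replacement token id when last_token_id falls outside the
    vocabulary, or None when the token is valid and decoding may continue."""
    if 0 <= last_token_id <= vocab_size:
        return None
    replacement = last_token_id
    # Mirror the patch order: take a user-supplied stop token first, then
    # overwrite it with a model EOS token if one exists (EOS wins when both
    # are configured).
    if stop_token_ids:
        replacement = next(iter(stop_token_ids))
    if eos_token_ids:
        replacement = next(iter(eos_token_ids))
    return replacement

# A garbage id far beyond the vocabulary is rewritten to the EOS token; the
# real code then sets FINISH_MATCHED_STR(matched="NaN happened") and returns.
print(substitute_invalid_token(10**9, vocab_size=32000, stop_token_ids=None, eos_token_ids=[2]))  # -> 2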

python/sglang/srt/managers/scheduler.py

Lines changed: 5 additions & 2 deletions
@@ -1129,6 +1129,7 @@ def handle_generate_request(
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
                data_parallel_rank=recv_req.data_parallel_rank,
+                vocab_size=self.model_config.vocab_size,
            )
            req.tokenizer = self.tokenizer

@@ -1395,8 +1396,10 @@ def log_prefill_stats(
        logger.info(f)

        if self.enable_metrics:
-            cache_hit_rate = adder.log_hit_tokens / (
-                adder.log_input_tokens + adder.log_hit_tokens
+            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
+
+            cache_hit_rate = (
+                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
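The second scheduler.py hunk guards the cache hit rate so that a prefill round with zero input tokens and zero hit tokens reports 0.0 instead of raising ZeroDivisionError. A standalone sketch of the same expression, with the counters passed as plain arguments standing in for adder.log_hit_tokens and adder.log_input_tokens:

def compute_cache_hit_rate(hit_tokens: int, input_tokens: int) -> float:
    total_tokens = input_tokens + hit_tokens
    # The old code divided unconditionally and crashed when both counters
    # were zero; the patched expression falls back to 0.0 in that case.
    return hit_tokens / total_tokens if total_tokens > 0 else 0.0

print(compute_cache_hit_rate(0, 0))    # 0.0 instead of ZeroDivisionError
print(compute_cache_hit_rate(30, 70))  # 0.3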
