diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index a6708024f87..4c38d9d4fb0 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -14,6 +14,7 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector from sglang.srt.function_call.mistral_detector import MistralDetector from sglang.srt.function_call.pythonic_detector import PythonicDetector +from sglang.srt.function_call.qwen3_detector import Qwen3XMLDetector from sglang.srt.function_call.qwen25_detector import Qwen25Detector logger = logging.getLogger(__name__) @@ -35,6 +36,7 @@ class FunctionCallParser: "deepseekv3": DeepSeekV3Detector, "pythonic": PythonicDetector, "kimi_k2": KimiK2Detector, + "qwen3": Qwen3XMLDetector, } def __init__(self, tools: List[Tool], tool_call_parser: str): diff --git a/python/sglang/srt/function_call/qwen3_detector.py b/python/sglang/srt/function_call/qwen3_detector.py new file mode 100644 index 00000000000..5c6ac698e8e --- /dev/null +++ b/python/sglang/srt/function_call/qwen3_detector.py @@ -0,0 +1,150 @@ +import ast +import html +import json +import logging +import re +from typing import Any, Dict, List, Tuple + +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + StructureInfo, + ToolCallItem, + _GetInfoFunc, +) +from sglang.srt.function_call.ebnf_composer import EBNFComposer + +logger = logging.getLogger(__name__) + + +def _safe_val(raw: str) -> Any: + raw = html.unescape(raw.strip()) + try: + return json.loads(raw) + except Exception: + try: + return ast.literal_eval(raw) + except Exception: + return raw + + +class Qwen3XMLDetector(BaseFormatDetector): + """ + Detector for Qwen 3 models. + Assumes function call format: + + + + pwd && ls + + + + """ + + def __init__(self): + super().__init__() + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + self.tool_call_prefix: str = "(.*?)|(.*?)$", re.DOTALL + ) + self.tool_call_function_regex = re.compile( + r"|| bool: + return self.tool_call_start_token in text + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + normal, calls = self._extract(text, tools) + return StreamingParseResult(normal_text=normal, calls=calls) + + def parse_streaming_increment( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + self._buf += new_text + normal = "" + calls: List[ToolCallItem] = [] + while True: + if self.tool_call_start_token not in self._buf: + normal += self._buf + self._buf = "" + break + s = self._buf.find(self.tool_call_start_token) + if s > 0: + normal += self._buf[:s] + self._buf = self._buf[s:] + e = self._buf.find(self.tool_call_end_token) + if e == -1: + break + block = self._buf[: e + len(self.tool_call_end_token)] + self._buf = self._buf[e + len(self.tool_call_end_token) :] + calls.extend(self._parse_block(block, tools)) + return StreamingParseResult(normal_text=normal, calls=calls) + + def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]: + normal_parts: List[str] = [] + calls: List[ToolCallItem] = [] + cursor = 0 + while True: + s = text.find(self.tool_call_start_token, cursor) + if s == -1: + normal_parts.append(text[cursor:]) + break + normal_parts.append(text[cursor:s]) + e = text.find(self.tool_call_end_token, s) + if e == -1: + normal_parts.append(text[s:]) + break + block = text[s : e + len(self.tool_call_end_token)] + cursor = e + len(self.tool_call_end_token) + calls.extend(self._parse_block(block, tools)) + return "".join(normal_parts), calls + + def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]: + res: List[ToolCallItem] = [] + for m in self.tool_call_function_regex.findall(block): + txt = m[0] if m[0] else m[1] + if ">" not in txt: + continue + idx = txt.index(">") + fname = txt[:idx].strip() + body = txt[idx + 1 :] + params: Dict[str, Any] = {} + for pm in self.tool_call_parameter_regex.findall(body): + ptxt = pm[0] if pm[0] else pm[1] + if ">" not in ptxt: + continue + pidx = ptxt.index(">") + pname = ptxt[:pidx].strip() + pval = ptxt[pidx + 1 :].lstrip("\n").rstrip("\n") + params[pname] = _safe_val(pval) + raw = {"name": fname, "arguments": params} + try: + res.extend(self.parse_base_json(raw, tools)) + except Exception: + logger.warning("invalid tool call for %s dropped", fname) + return res + + def structure_info(self) -> _GetInfoFunc: + return lambda n: StructureInfo( + begin=f"{self.tool_call_start_token}\n", + end=f"\n{self.tool_call_end_token}", + trigger=self.tool_call_start_token, + ) + + # TODO: fake ebnf for xml + outlines backend + def build_ebnf(self, tools: List[Tool]): + return EBNFComposer.build_ebnf( + tools, + individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"), + individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"), + tool_call_separator="\\n", + function_format="json", + ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6464f9f40a3..400a1bf99e8 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1099,6 +1099,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "deepseekv3", "pythonic", "kimi_k2", + "qwen3", ], default=ServerArgs.tool_call_parser, help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.",