-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Preliminary Support for Qwen3XMLDetector #8260
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import ast | ||
import html | ||
import json | ||
import logging | ||
import re | ||
from typing import Any, Dict, List, Tuple | ||
|
||
from sglang.srt.entrypoints.openai.protocol import Tool | ||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector | ||
from sglang.srt.function_call.core_types import ( | ||
StreamingParseResult, | ||
StructureInfo, | ||
ToolCallItem, | ||
_GetInfoFunc, | ||
) | ||
from sglang.srt.function_call.ebnf_composer import EBNFComposer | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def _safe_val(raw: str) -> Any: | ||
raw = html.unescape(raw.strip()) | ||
try: | ||
return json.loads(raw) | ||
except Exception: | ||
try: | ||
return ast.literal_eval(raw) | ||
except Exception: | ||
return raw | ||
|
||
|
||
class Qwen3XMLDetector(BaseFormatDetector):
    """
    Detector for Qwen 3 models.

    Assumes function call format:
        <tool_call>
        <function=execute_bash>
        <parameter=command>
        pwd && ls
        </parameter>
        </function>
        </tool_call>

    Text outside ``<tool_call>...</tool_call>`` blocks is reported as normal
    text; each complete block is parsed into tool call items.
    """

    def __init__(self):
        """Set up the delimiter tokens, the parsing regexes and the stream buffer."""
        super().__init__()
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        # Each regex has two alternatives: a properly closed element, or an
        # unterminated one running to the end of the text.
        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
        )
        self.tool_call_function_regex = re.compile(
            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
        )
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
        )
        # Streamed text that has not yet been emitted or parsed.
        self._buf: str = ""

    def has_tool_call(self, text: str) -> bool:
        """Return True if ``text`` contains the tool call start token."""
        return self.tool_call_start_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """Parse a complete (non-streamed) text into normal text and tool calls."""
        normal, calls = self._extract(text, tools)
        return StreamingParseResult(normal_text=normal, calls=calls)

    def _partial_start_len(self, text: str) -> int:
        """Return the length of the longest proper prefix of the start token
        that ``text`` ends with, or 0 if there is none."""
        token = self.tool_call_start_token
        for size in range(min(len(token) - 1, len(text)), 0, -1):
            if text.endswith(token[:size]):
                return size
        return 0

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """Consume one streamed chunk and return whatever became complete.

        Text before a tool call is returned as normal text. A tool call is
        only parsed once its end token has arrived; until then it stays in
        the internal buffer.

        Args:
            new_text: The newly generated text chunk.
            tools: Available tools, forwarded to the block parser.

        Returns:
            A StreamingParseResult with the normal text and the tool calls
            completed by this chunk.
        """
        self._buf += new_text
        normal_parts: List[str] = []
        calls: List[ToolCallItem] = []
        start = self.tool_call_start_token
        end = self.tool_call_end_token

        while True:
            s = self._buf.find(start)
            if s == -1:
                # No complete start token. The start token may be split across
                # chunks (e.g. "<tool_" then "call>"), so hold back a trailing
                # partial match instead of leaking it as normal text.
                # NOTE(review): the held-back suffix is only released when more
                # text arrives; nothing here flushes it at end of stream.
                cut = len(self._buf) - self._partial_start_len(self._buf)
                normal_parts.append(self._buf[:cut])
                self._buf = self._buf[cut:]
                break
            if s > 0:
                normal_parts.append(self._buf[:s])
                self._buf = self._buf[s:]
            e = self._buf.find(end)
            if e == -1:
                # Tool call still incomplete; wait for more text.
                break
            stop = e + len(end)
            block, self._buf = self._buf[:stop], self._buf[stop:]
            calls.extend(self._parse_block(block, tools))

        return StreamingParseResult(normal_text="".join(normal_parts), calls=calls)

    def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
        """Split ``text`` into normal text and parsed tool calls.

        An unterminated ``<tool_call>`` (no end token) is kept verbatim as
        normal text and ends the scan.

        Returns:
            A tuple of (concatenated normal text, list of tool call items).
        """
        normal_parts: List[str] = []
        calls: List[ToolCallItem] = []
        cursor = 0
        while True:
            s = text.find(self.tool_call_start_token, cursor)
            if s == -1:
                normal_parts.append(text[cursor:])
                break
            normal_parts.append(text[cursor:s])
            e = text.find(self.tool_call_end_token, s)
            if e == -1:
                normal_parts.append(text[s:])
                break
            cursor = e + len(self.tool_call_end_token)
            calls.extend(self._parse_block(text[s:cursor], tools))
        return "".join(normal_parts), calls

    def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
        """Parse every ``<function=...>`` element found in one tool call block.

        Args:
            block: Text of a single ``<tool_call>...</tool_call>`` block.
            tools: Available tools, forwarded to ``parse_base_json``.

        Returns:
            The tool call items produced for the block. A function whose
            conversion raises is logged and skipped.
        """
        res: List[ToolCallItem] = []
        for m in self.tool_call_function_regex.findall(block):
            # findall yields (closed, unterminated); exactly one is non-empty.
            txt = m[0] if m[0] else m[1]
            if ">" not in txt:
                continue
            # Everything before the first ">" is the function name.
            fname, _, body = txt.partition(">")
            fname = fname.strip()
            params: Dict[str, Any] = {}
            for pm in self.tool_call_parameter_regex.findall(body):
                ptxt = pm[0] if pm[0] else pm[1]
                if ">" not in ptxt:
                    continue
                pname, _, pval = ptxt.partition(">")
                # Only the newlines wrapping the value are removed here.
                params[pname.strip()] = _safe_val(pval.strip("\n"))
            raw = {"name": fname, "arguments": params}
            try:
                res.extend(self.parse_base_json(raw, tools))
            except Exception as e:
                # Deliberately best-effort: one bad call must not abort the
                # rest, but record why it was dropped.
                logger.warning("invalid tool call for %s dropped: %s", fname, e)
        return res

    def structure_info(self) -> _GetInfoFunc:
        """Return a function mapping a tool name to its begin/end/trigger strings."""
        return lambda n: StructureInfo(
            begin=f"{self.tool_call_start_token}\n<function={n}>",
            end=f"</function>\n{self.tool_call_end_token}",
            trigger=self.tool_call_start_token,
        )

    # TODO: fake ebnf for xml + outlines backend
    def build_ebnf(self, tools: List[Tool]):
        """Build an EBNF grammar for ``tools`` via EBNFComposer.

        NOTE(review): this passes ``function_format="json"``, so the grammar
        does not describe the XML-like format documented on this class; see
        the TODO above.
        """
        return EBNFComposer.build_ebnf(
            tools,
            individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
            individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
            tool_call_separator="\\n",
            function_format="json",
        )
Comment on lines
+142
to
+150
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The `build_ebnf` method builds its grammar with `function_format="json"`, which does not match this detector's XML-like tool call format. While the PR description mentions this is a known issue, it's a critical part of the functionality. A proper EBNF grammar for the XML-like syntax needs to be constructed. This would likely involve creating a new EBNF template for this format, rather than reusing the JSON one. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1099,6 +1099,7 @@ def add_cli_args(parser: argparse.ArgumentParser): | |
"deepseekv3", | ||
"pythonic", | ||
"kimi_k2", | ||
"qwen3", | ||
], | ||
Comment on lines
+1102
to
1103
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The help text for this argument should be updated to list the new 'qwen3' option: help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3'.", |
||
default=ServerArgs.tool_call_parser, | ||
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `except Exception:` clauses are very broad. It's better to catch more specific exceptions to avoid masking unexpected errors. For `json.loads`, you should catch `json.JSONDecodeError`. For `ast.literal_eval`, you should catch `ValueError` and `SyntaxError`. This makes the error handling more robust and prevents hiding potential bugs.