How sglang handles function_call

Whether it is prompt-based or API-based, a function call is ultimately just token processing.

  • A prompt-based function call puts the structured-output requirement into the system prompt, then runs regex matching over the reply to extract the function call.
  • An API-based function call hands the work to the inference framework: the request carries a tools field, and the framework tokenizes the tools content together with the prompt, so the input is still just tokens. On the output side, the framework regex-matches the structured function-call format; if it matches, the output is treated as a function call and the response's finish_reason is set to tool_calls, otherwise it is treated as plain text (see the request sketch below).
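
A minimal sketch of the API-based path against sglang's OpenAI-compatible endpoint; the server address, model name, and the get_weather tool are illustrative placeholders, not taken from the sglang source:

from openai import OpenAI

# Assumes an sglang server is already running locally (sglang's default port is 30000).
client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",                      # placeholder tool
            "description": "Query the current weather of a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

resp = client.chat.completions.create(
    model="Qwen/Qwen2.5-7B-Instruct",                   # placeholder model name
    messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
    tools=tools,
)

choice = resp.choices[0]
if choice.finish_reason == "tool_calls":                # the parser matched a function call
    fn = choice.message.tool_calls[0].function
    print(fn.name, fn.arguments)
else:                                                   # plain text output
    print(choice.message.content)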

NOTE

The following uses sglang + qwen in non-streaming mode to walk through how sglang parses function-call requests and builds the responses.

First, when launching the qwen model with sglang, add --tool-call-parser qwen.

The qwen parser applies to all Qwen models except Qwen3 Coder; for a Qwen3 Coder model it should be --tool-call-parser qwen3_coder.
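
A sketch of launching the server with that flag from Python (equivalent to running python -m sglang.launch_server on the command line); the model path and port are placeholders:

import subprocess

# Equivalent shell command: python -m sglang.launch_server --model-path ... --tool-call-parser qwen
server = subprocess.Popen([
    "python", "-m", "sglang.launch_server",
    "--model-path", "Qwen/Qwen2.5-7B-Instruct",   # placeholder model path
    "--tool-call-parser", "qwen",                 # the parser name used in this post
    "--port", "30000",
])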

serving_chat.py builds the response that is returned to the user:

# Entry point for building the response
def _build_chat_response(
    self,
    request: ChatCompletionRequest,
    ret: List[Dict[str, Any]],
    created: int,
) -> Union[ChatCompletionResponse, ORJSONResponse]:
    """Build chat completion response from generation results"""
    choices = []

    for idx, ret_item in enumerate(ret):
        finish_reason = ret_item["meta_info"]["finish_reason"]
        text = ret_item["text"]

        # ......

        # Handle tool calls: the function-call handling happens here
        tool_calls = None
        if (
            request.tool_choice != "none"
            and request.tools
            and self.tool_call_parser
        ):
            history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
            # The function below does the actual matching
            tool_calls, text, finish_reason = self._process_tool_calls(
                text,
                request.tools,
                finish_reason,
                request.tool_choice,
                history_tool_calls_cnt,
            )

        # ......
    # ......

    return ChatCompletionResponse(
        id=ret[0]["meta_info"]["id"],
        created=created,
        model=request.model,
        choices=choices,
        usage=usage,
        metadata={"weight_version": ret[0]["meta_info"]["weight_version"]},
    )

def _process_tool_calls(
    self,
    text: str,
    tools: List[Any],
    finish_reason: Dict[str, Any],
    tool_choice: Optional[Union[str, ToolChoice]] = None,
    history_tool_calls_cnt: int = 0,
) -> ToolCallProcessingResult:
    """Process tool calls in the response"""

    # Handle required or named tool choice
    # ......

    # Dispatch to the actual parsing here; different models have different parsing details
    # Use parser since output is not constrained by JSON schema
    parser = FunctionCallParser(tools, self.tool_call_parser)
    if parser.has_tool_call(text):
        if finish_reason["type"] == "stop":
            finish_reason["type"] = "tool_calls"
            finish_reason["matched"] = None
        try:
            text, call_info_list = parser.parse_non_stream(text)
            tool_calls = []
            for call_info in call_info_list:
                tool_id = self._process_tool_call_id(
                    call_info, history_tool_calls_cnt
                )
                tool_calls.append(
                    ToolCall(
                        id=tool_id,
                        index=getattr(call_info, "tool_index", None),
                        function=FunctionResponse(
                            name=call_info.name, arguments=call_info.parameters
                        ),
                    )
                )
            return ToolCallProcessingResult(tool_calls, text, finish_reason)
        except Exception as e:
            logger.error(f"Tool call parsing error: {e}")
            # Return error but don't fail the whole request
            return ToolCallProcessingResult(None, text, finish_reason)

    return ToolCallProcessingResult(None, text, finish_reason)
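
When the parser does match, the choice that the client eventually receives looks roughly like this; the values are made up, only the shape follows the OpenAI-compatible schema that ChatCompletionResponse serializes to:

{
    "index": 0,
    "finish_reason": "tool_calls",             # flipped from "stop" by _process_tool_calls
    "message": {
        "role": "assistant",
        "content": None,                       # leftover normal text (if any) goes here
        "tool_calls": [
            {
                "id": "call_0",                # placeholder; real ids come from _process_tool_call_id
                "type": "function",
                "index": 0,
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"city\": \"Beijing\"}",   # a JSON string, not an object
                },
            }
        ],
    },
}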

In other words, the flow in _process_tool_calls is:

  1. First, parser.has_tool_call(text) must be True; for qwen, that means the text must contain "<tool_call>\n".

  2. Then parser.parse_non_stream(text) is called to do the matching, which in turn calls self.detector.detect_and_parse(full_text, self.tools).

    This takes the non-streaming path as the example; in streaming mode the output chunks are continuously concatenated and parsed on the fly, while in non-streaming mode everything is parsed once after generation finishes.

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.
    
        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
        """
        idx = text.find(self.bot_token)
        normal_text = text[:idx].strip() if idx != -1 else text
        if self.bot_token not in text:
            return StreamingParseResult(normal_text=normal_text, calls=[])
    
        # Match the format with a regular expression
        # Find all <tool_call>\n...\n</tool_call> blocks
        pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
        match_result_list = re.findall(pattern, text, re.DOTALL)
        calls = []
        # The output may contain multiple function calls, so the matches come back as a list
        for match_result in match_result_list:
            try:
                parsed_call = json.loads(match_result.strip())
                calls.extend(self.parse_base_json(parsed_call, tools))
            except json.JSONDecodeError as e:
                logger.warning(
                    f"Failed to parse JSON part: {match_result}, JSON parse error: {str(e)}"
                )
                continue
        # If nothing matched, calls stays empty and this returns StreamingParseResult(normal_text=normal_text, calls=[])
        return StreamingParseResult(normal_text=normal_text, calls=calls)
    
    class ToolCallItem(BaseModel):
        """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""

        tool_index: int
        name: Optional[str] = None
        parameters: str  # JSON string


    class StreamingParseResult(BaseModel):
        """Result of streaming incremental parsing."""

        normal_text: str = ""
        calls: List[ToolCallItem] = []
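
To make the matching concrete, here is a standalone sketch of the regex logic above that does not import sglang; the bot_token/eot_token values follow the Qwen <tool_call> format described in this post, and the sample output is made up:

import json
import re

bot_token = "<tool_call>\n"
eot_token = "\n</tool_call>"

# Made-up model output: some normal text followed by one tool-call block.
text = (
    "Let me check the weather for you.\n"
    '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Beijing"}}\n</tool_call>'
)

idx = text.find(bot_token)
normal_text = text[:idx].strip() if idx != -1 else text

# Same pattern as detect_and_parse: capture everything between <tool_call>\n and \n</tool_call>
pattern = rf"{re.escape(bot_token)}(.*?){re.escape(eot_token)}"
match_result_list = re.findall(pattern, text, re.DOTALL)

calls = [json.loads(m.strip()) for m in match_result_list]

print(normal_text)  # Let me check the weather for you.
print(calls)        # [{'name': 'get_weather', 'arguments': {'city': 'Beijing'}}]

In the real flow, parse_base_json then checks each parsed name against the tools list and serializes the arguments into the JSON-string parameters of a ToolCallItem, which _process_tool_calls finally wraps into the ToolCall objects shown earlier.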