3-router调度请求进行推理

NOTE

前两节我们讲了进程的关系和每个进程的作用,和用户发送请求后的http流转流程

这一节我们接着从请求的视角,来讲router调度请求和实际执行推理

Router 是 nano-sglang 框架中的核心调度组件,负责接收来自 Tokenizer 的请求队列,管理内存池,调度 Prefill 和 Decode 操作,并将生成的 token 发送给 Detokenizer。

核心调度流程

请求接收与转换

Router 通过 exposed_step 方法接收来自 Tokenizer 的请求:

1
2
3
4
5
6
7
8
9
def exposed_step(self, recv_reqs):
    """被异步包装的函数,recv_reqs就是zmq接收到的队列,异步执行当前的step调用推理"""
    # 把请求从TokenizedGenerateReqInput格式转换成推理中的Req类型,并做一些初始化操作
    for recv_req in recv_reqs:
        if isinstance(recv_req, TokenizedGenerateReqInput):
            self.add_request(recv_req)

    # Forward # 实际推理的位置
    self.forward_step()

推理主流程 (forward_step)

大模型的推理分为两个阶段:Prefill 和 Decode

在sglang中,依然是优先处理prefill的请求,再处理decode 请求。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def forward_step(self):
    # 对新到的序列做管理,由于空间有限,只能推理有限的序列
    new_batch = self.get_new_fill_batch()

    # 有新的序列需要做prefill,prefill优先
    if new_batch is not None:
        # Run new fill batch
        self.forward_fill_batch(new_batch)

        # 上面的new_batch在做完forward后,留下了还需要继续做decode的内容
        # 和原来的runing_batch放到一起做decode
        if not new_batch.is_empty():
            if self.running_batch is None:
                self.running_batch = new_batch
            else:
                self.running_batch.merge(new_batch)

    # Run decode batch
    if self.running_batch is not None:
        self.forward_decode_batch(self.running_batch)

完整函数调用流程

以下是一个完整的 step 过程中的函数调用链:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
1. exposed_step(recv_reqs)  [ModelRpcServer]
   └── 接收来自 Tokenizer 的请求队列
   ├── add_request(recv_req)  [转换 TokenizedGenerateReqInput → Req]
   │   ├── forward_queue.append(req)
   │   └── 初始化请求参数
   └── forward_step()  [主推理流程]

2. forward_step()  [ModelRpcServer]
   ├── get_new_fill_batch()  [构建 Prefill 批次]
   │   ├── 检查 running_batch 是否已满
   │   ├── 遍历 forward_queue 中的请求
   │   │   ├── tree_cache.match_prefix()  [前缀匹配]
   │   │   ├── 计算 adjust_input_len
   │   │   └── 设置 prefix_indices, last_node
   │   ├── scheduler.get_priority_queue()  [请求排序]
   │   │   └── 按 schedule_heuristic 排序 (默认 LPM)
   │   ├── 计算可用内存空间
   │   │   ├── token_to_kv_pool.available_size()
   │   │   ├── tree_cache.evictable_size()
   │   │   └── 减去 running_batch 的预估内存
   │   ├── 遍历请求进行筛选
   │   │   ├── 检查内存限制
   │   │   ├── 检查 max_prefill_num_token 限制
   │   │   ├── tree_cache.inc_ref_counter()  [尝试复用 KV 缓存]
   │   │   └── token_to_kv_pool.add_refs()
   │   └── 创建 Batch 对象
   │       ├── Batch()
   │       └── 从 forward_queue 中移除已添加的请求

   ├── forward_fill_batch(new_batch)  [Prefill 推理]
   │   ├── batch.init_extend_batch()
   │   │   ├── 准备 GPU 张量
   │   │   └── 设置 positions 等参数
   │   ├── model_runner.forward(batch, ForwardMode.EXTEND)
   │   │   └── 调用具体模型 (Qwen/Llama) 的 forward
   │   ├── batch.sample(logits)
   │   │   ├── top_k/top_p 采样
   │   │   └── 获取 next_token_ids
   │   ├── 遍历请求检查完成条件
   │   │   ├── req.output_ids = [next_token_ids[i]]
   │   │   └── req.check_finish()
   │   └── handle_finished_requests(batch)
   │       └── 将完成的请求发送到 Detokenizer

   └── forward_decode_batch(running_batch)  [Decode 推理]
       ├── batch.update_for_decode()
       │   └── 准备 decode 张量
       ├── model_runner.forward(batch, ForwardMode.DECODE)
       │   └── 调用模型进行 decode
       ├── batch.sample(logits)
       │   └── 采样获取下一个 token
       ├── 处理输出
       │   ├── 构建 BatchTokenIDOut
       │   └── out_pyobjs.append()
       └── 清理完成的请求
           ├── tree_cache.dec_ref_counter()
           ├── batch.filter_batch()
           └── 更新 running_batch

关键数据结构转换流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
TokenizedGenerateReqInput [来自 Tokenizer]

add_request()

Req [Router 内部格式]
    ├── input_ids: List[int]
    ├── prefix_indices: List[int] [匹配的前缀]
    ├── adjust_input_len: int [需要推理的长度]
    ├── last_node: TreeNode [KV 缓存节点]
    └── max_new_tokens: int [最大生成长度]

Batch [推理批次]
    ├── reqs: List[Req]
    ├── input_ids: Tensor
    ├── positions: Tensor
    └── 其他 GPU 张量

BatchTokenIDOut [发送到 Detokenizer]
    ├── req_ids: List[str]
    ├── token_ids: List[List[int]]
    └── finished: List[bool]

批次生命周期

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
1. 创建阶段
   └── get_new_fill_batch()
       ├── 从 forward_queue 选择请求
       ├── 创建 Batch 对象
       └── 初始化批次张量

2. Prefill 阶段
   └── forward_fill_batch()
       ├── batch.init_extend_batch()
       ├── 执行模型推理
       └── 生成第一个 token

3. Decode 阶段
   └── forward_decode_batch()
       ├── batch.update_for_decode()
       ├── 循环执行 decode
       └── 逐个生成 token

4. 完成阶段
   ├── 发送结果到 Detokenizer
   ├── 释放 KV 缓存
   └── 从批次中移除完成的请求