2-流式响应架构

TokenizerManager 流式响应架构详解

总体概览:

sequenceDiagram
    participant User as 用户
    participant API as FastAPI Server
    participant TM as TokenizerManager
    participant Router as Router进程
    participant Model as Model RPC
    participant Detok as Detokenizer

    User->>API: POST /generate (GenerateReqInput)
    API->>API: obj.post_init()
    API->>TM: generate_request(obj)

    TM->>TM: 第一次请求创建handle_loop
    TM->>TM: tokenizer.encode(text)
    TM->>TM: SamplingParams处理
    TM->>TM: 图像处理(如果有)
    TM->>TM: 创建TokenizedGenerateReqInput

    TM->>Router: send_pyobj(tokenized_obj)

    Note over Router: 请求调度和批处理
    Router->>Model: RPyC调用模型推理

    loop 模型推理
        Model->>Model: 前向推理生成tokens
    end

    Model->>Router: 返回生成的token IDs
    Router->>Detok: 发送BatchTokenIDOut

    Note over Detok: Token到文本转换
    Detok->>TM: 发送BatchStrOut

    par handle_loop处理
        TM->>TM: recv_pyobj()接收结果
        TM->>TM: 存储到state.out_list
        TM->>TM: state.event.set()通知
    end

    par generate_request等待
        TM->>TM: await event.wait()
        TM->>TM: yield state.out_list[-1]
        TM->>API: 异步生成器产生结果
    end

    API->>API: 流式或非流式响应
    API->>User: 返回最终结果

    Note over User,Detok: 流式输出时重复上述循环,直到finished=true

请求处理:支持单个和多个请求

请求入口与数据结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# server.py:40-53
@app.post("/generate")
async def generate_request(obj: GenerateReqInput):
    obj.post_init()
    result_generator = tokenizer_manager.generate_request(obj)

    if obj.stream:
        async def stream_results():
            async for out in result_generator:
                yield (json.dumps(out) + "\0").encode("utf-8")
        return StreamingResponse(stream_results(), media_type="text/event-stream")
    else:
        ret = await result_generator.__anext__()
        return ret

GenerateReqInput 数据结构

1
2
3
4
5
6
7
8
9
10
# io_struct.py:9-59
@dataclass
class GenerateReqInput:
    text: Union[List[str], str]           # 支持单个字符串或字符串列表
    image_data: Optional[Union[List[str], str]] = None
    sampling_params: Union[List[Dict], Dict] = None
    rid: Optional[Union[List[str], str]] = None
    return_normalized_logprob: Optional[Union[List[bool], bool]] = None
    normalized_logprob_start_len: Optional[Union[List[int], int]] = None
    stream: bool = False

关键特性:

  • 灵活的输入格式text 可以是单个字符串(单个请求)或字符串列表(批量请求)
  • 批量处理一致性:对于多个请求,返回值也是多个一起返回,保持请求与响应的对应关系
  • 自动 RID 分配:如果没有提供 rid,系统会自动为每个请求生成唯一标识符
1
2
3
4
5
# io_struct.py:25, 44
if self.rid is None:
    self.rid = uuid.uuid4().hex          # 单个请求
else:
    self.rid = [uuid.uuid4().hex for _ in range(num)]  # 批量请求

全局唯一的 handle_loop

延迟初始化机制

1
2
3
4
5
6
7
8
9
# tokenizer_manager.py:113-114, 198-201
async def generate_request(self, obj: GenerateReqInput):
    if self.to_create_loop:              # 只有第一次请求时为 True
        await self.create_handle_loop()  # 创建唯一的后台循环

async def create_handle_loop(self):
    self.to_create_loop = False          # 创建后立即设为 False
    loop = asyncio.get_event_loop()
    loop.create_task(self.handle_loop()) # 启动全局唯一循环

整个 TokenizerManager 实例只有一个 handle_loop 循环,第一次请求时才创建,避免不必要的资源消耗

循环的生命周期

1
2
3
4
5
6
7
8
9
10
11
# tokenizer_manager.py:203-219
async def handle_loop(self):
    while True:                          # ♾️ 永久运行的后台循环
        recv_obj = await self.recv_from_detokenizer.recv_pyobj()
        if isinstance(recv_obj, BatchStrOut):
            for i, rid in enumerate(recv_obj.rids):
                # 处理多个请求的响应...
                state = self.rid_to_state[rid]
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished[i]
                state.event.set()         # 唤醒对应的 generate_request

一个循环处理所有请求的响应,在 await recv_from_detokenizer.recv_pyobj() 处阻塞,不消耗 CPU

rid_to_state:请求状态管理中心

ReqState 数据结构

1
2
3
4
5
6
7
# tokenizer_manager.py:30-35
@dataclasses.dataclass
class ReqState:
    out_list: List          # 该请求的所有输出片段
    finished: bool          # 该请求是否完成
    event: asyncio.Event    # 异步事件信号
    lock: asyncio.Lock      # 异步锁

状态注册与管理

1
2
3
4
5
6
7
8
9
10
11
# tokenizer_manager.py:101, 143-144
class TokenizerManager:
    def __init__(self, ...):
        self.rid_to_state = {}  # 全局请求状态字典

    async def generate_request(self, obj: GenerateReqInput):
        # 为每个请求创建状态
        lock = asyncio.Lock()
        event = asyncio.Event()
        state = ReqState([], False, event, lock)
        self.rid_to_state[rid] = state  # 注册到全局字典

核心作用:

  • 请求映射ridReqState 的映射关系
  • 状态隔离:每个请求有独立的 out_list,互不干扰
  • 异步同步:通过 asyncio.Event 实现生产者-消费者同步

双循环异步通信机制

生产者-消费者模式

这个设计类似于 C++ 中的队列+信号量机制:

循环1:generate_request 中的消费循环

1
2
3
4
5
6
7
8
9
# tokenizer_manager.py:146-153
while True:
    await event.wait()               # 等待信号(类似等待队列非空)
    yield state.out_list[-1]         # 消费数据(从队列取出)
    state.out_list = []              # 清空已消费数据
    if state.finished:
        del self.rid_to_state[rid]   # 完成后清理状态
        break
    event.clear()                    # 重置信号(类似重置信号量)

循环2:handle_loop 中的生产循环

1
2
3
4
# tokenizer_manager.py:214-217
state.out_list.append(out_dict)      # 生产数据(入队)
state.finished = recv_obj.finished[i]
state.event.set()                    # 发送信号(V操作)