sglang投机采样mtp

如果开启了MTP,那么调用的worker是EAGLEWorker:class EAGLEWorker(TpModelWorker):

单次的decode推理大体流程为:

1
2
3
4
5
6
7
8
9
10
11
12
13
单次decode forward耗时12ms,平均得到3个token

- forward 11.4ms
    -  draft 1.8ms
        - _draft_preprocess_decode
        - init_forward_metadata_replay_cuda_graph 0.45ms
        - cudagraph 0.6ms
    -  verify 9.5ms
        - replay_prepare 1ms
        - cudagraph 6.2ms
        - mamba_verify_update 0.86ms
            - update_mamba_state_after_mtp_verify 0.47ms (这里有算子融合可以优化)
- process_batch_resule 0.7ms

分配kvcache slot空间

1
2
3
4
5
6
7
8
9
10
11
# Allocate cache locations
# Layout of the out_cache_loc
# [       topk 0         ] [       topk 1         ]
# [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
if self.page_size == 1:
  # TODO: We only need self.speculative_num_steps - 1 * topk cache loc
  out_cache_loc, token_to_kv_pool_state_backup = alloc_token_slots(
      batch.tree_cache,
      num_seqs * self.speculative_num_steps * self.topk,
      backup_state=True,
  )

这里的 out_cache_loc 是分配 KV cache的token slot位置,用于存储Eagle speculative decoding过程中的中间状态。

数值含义分解

num_seqs * self.speculative_num_steps * self.topk:

  • num_seqs: 批次中的序列数量
  • self.speculative_num_steps: 推测性解码的步数(每次draft生成多少个token)
  • self.topk: Eagle模型在每个位置考虑的候选token数量

为什么需要这个大小

根据注释中的布局说明:

[ topk 0 ] [ topk 1 ]
[iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]

Eagle的draft阶段需要:
1. 对每个序列(num_seqs)
2. 在每个推测步数(speculative_num_steps)
3. 为每个topk候选(topk)

都分配一个KV cache slot来存储该draft分支的中间状态。

实际例子

假设:
- num_seqs = 8(batch size)
- speculative_num_steps = 4(每次draft 4个token)
- topk = 4(每个位置4个候选)

则分配:8 × 4 × 4 = 128个token slots

这些slots用于存储所有可能的draft分支路径的KV cache,以便后续验证和选择最优路径。

assign_draft_cache_locs_with_sum 算子

这个Triton算子是Eagle speculative decoding中最关键的cache管理内核,主要负责将draft阶段的KV cache位置分配并复制到正确的位置。

核心功能(分三部分)

Part 1: 将 out_cache_loc 复制到 req_to_token (L305-313)

1
2
3
4
# 将分配的cache位置写入token pool
kv_start = my_seq_len  # 从当前序列长度开始
token_pool = req_to_token + req_pool_index * pool_len
# 复制 copy_len = topk * speculative_num_steps 个位置

这一步将之前分配的out_cache_loc写入到每个请求的token池中,让后续的kernel知道KV cache在哪里。

Part 2: 处理跨页复制 (L314-346)

page_size != 1topk > 1 时,需要处理KV cache的跨页复制:

场景:假设当前序列最后一页没满(last_page_len=3,page_size=4)

1
2
topk=0: [token1, token2, token3, draft0_0, draft0_1, ...]  # 原始分支
topk=1: [token1, token2, token3, draft1_0, draft1_1, ...]  # 新分支需要复制前缀

生成 source_cache_loc 和 target_cache_loc: - source: 原始前缀的位置 [8, 9, 10, 8, 9, 10...] - target: 新分支的目标位置 [16, 17, 18, 24, 25, 26...]

这样后续kernel可以将前缀KV cache复制到新分支。

Part 3: 整理输出cache位置 (L347-380)

将实际的draft token位置整理回 out_cache_loc,跳过前缀部分:

1
2
3
4
 speculative_num_steps=5, page_size=4, last_page_len=1
 原始: - xxxxx .. | - xxxxx .. |
        ^ prefix   ^ speculative tokens
 输出: 只取 "xxxxx" 部分,去掉前缀 "-"

副功能:计算seq_lens总和 (L296-303)

1
2
3
4
5
if pid == num_programs - 1:
    # 最后一个program计算所有序列长度总和
    all_seq_lens = tl.load(seq_lens + bs_offset, mask=bs_offset < num_programs)
    total_sum = tl.sum(all_seq_lens)
    tl.store(seq_lens_sum_output, total_sum)

避免了GPU-CPU同步(不需要torch.sum().item()),直接在GPU上完成求和。

参数含义

参数 含义
req_pool_indices 每个请求在pool中的索引
req_to_token token池,存储KV cache位置映射
extend_lens 每个序列需要扩展的长度
out_cache_loc 输出:分配的cache位置数组
source_cache_loc 输出:KV cache复制的源位置
target_cache_loc 输出:KV cache复制的目标位置
topk, speculative_num_steps Eagle的超参数
page_size KV cache的页大小

关键优化点

  1. 并行处理:每个sequence一个program
  2. 避免GPU-CPU同步:在GPU上直接计算seq_lens总和
  3. 处理复杂边界:正确处理跨页复制和前缀共享

这个算子是Eagle性能优化的核心,确保多个draft分支能高效共享和复制KV cache。

恢复token_to_kv_pool_allocator状态

self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)

为什么分配后要恢复状态呢,因为draft实际上是一次假分配,拿到可以用的slots后就不用管了