mamba-radix-cache-extra-buffer

mamba-radix-cache-extra-buffer
gogongxtSGLang
Mamba Radix Cache 深度解析:extra_buffer 模式的 Ping-Pong
缓冲机制
#14792 | e61dabf5e |
2025-12-14 | [Qwen3-next] support mamba radix cache for overlap
scheduler - 引入
#15180 | 36fcf71ff | 2025-12-16 | [Qwen3-next] Add PD disaggregation support for mamba with extra_buffer - 为
extra_buffer 模式,支持 overlap
scheduler 和 Ping-Pong 缓冲机制 |#15180 | 36fcf71ff | 2025-12-16 | [Qwen3-next] Add PD disaggregation support for mamba with extra_buffer - 为
extra_buffer 添加 Prefill-Decode
disaggregation 支持 |相关配置参数
| 参数 | 默认值 | 说明 |
|---|---|---|
--mamba-scheduler-strategy |
auto |
调度策略:no_buffer、extra_buffer 或
auto(自动选择) |
--mamba-track-interval |
256 |
Decode 阶段的状态追踪间隔(token 数) |
--mamba-cache-chunk-size |
64 |
Prefill 阶段的缓存粒度,取
max(FLA_CHUNK_SIZE, page_size) |
--mamba-ssm-dtype |
float32 |
SSM 状态的数据类型,也可以设置成bfloat16,float16 |
--mamba-full-memory-ratio |
0.9 |
Mamba 状态内存与 KV cache 内存的比率 |
--max-mamba-cache-size |
None |
Mamba cache 的最大大小 |
关键参数详解:
mamba_cache_chunk_size:Prefill 阶段状态追踪的粒度。当请求长度超过此值时,在每个 chunk 边界追踪状态。默认取max(FLA_CHUNK_SIZE, page_size) = max(64, page_size),确保满足 FLA 64 位对齐要求。由于page_size默认为 1,实际默认值为 64。mamba_track_interval:Decode 阶段每隔多少 token 追踪一次状态。较小的值提供更细粒度的前缀匹配机会,但会增加追踪开销。ping_pong_track_buffer_size:每个请求额外分配的追踪 buffer 数量。Overlap scheduler 模式下为 2(真正的 Ping-Pong),非 overlap 模式为 1。
背景
在上一篇博客中,我们深入分析了 no_buffer
模式下的状态管理机制。我们看到,no_buffer 模式通过
fork_from 操作在请求完成时复制 Mamba 状态到新的 slot,确保
Radix Tree 持有独立的状态副本。
然而,这种每次缓存都需要复制状态的机制存在性能开销,还无法实现overlap调度。本篇将介绍另一种调度策略——extra_buffer,它通过
Ping-Pong 缓冲机制避免了频繁的状态复制。
核心问题:no_buffer
的性能瓶颈
在 no_buffer 模式下,每次缓存操作都涉及:
1
2
3
4
5
6
7
请求完成或 Chunk 边界
↓
fork_from(src_slot) → 分配新 slot + 复制状态
↓
插入 Radix Tree
↓
释放原始 slot状态复制的开销:
1
2
3
4
5
Mamba 状态大小(Qwen3-Next 单层):
- conv_state: [conv_dim/tp, conv_kernel-1] ≈ 1KB
- ssm_state: [num_heads/tp, head_dim, dstate] = [32, 128, 128] ≈ 2MB
总计(假设 32 层 Mamba): ~64MB 每次 fork当缓存命中率高时,频繁的状态复制会成为性能瓶颈。
extra_buffer
的核心思想
核心洞察:如果在请求执行过程中就预留额外的缓冲区来追踪中间状态,就可以在缓存时直接转移所有权,避免复制。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
┌─────────────────────────────────────────────────────────────────────────┐
│ extra_buffer 的 Ping-Pong 机制 │
└─────────────────────────────────────────────────────────────────────────┘
请求初始化时分配:
- mamba_pool_idx: slot A (用于 forward 计算)
- mamba_ping_pong_track_buffer: [slot B, slot C] (用于追踪中间状态)
Forward 计算:
- 使用 slot A 计算当前状态
- 在 track interval 边界,将状态追踪到 track_buffer[slot B]
缓存时:
- Radix Tree 接管 track_buffer[slot B]
- 请求保留 track_buffer[slot C] 用于后续追踪
- 无需状态复制!数据结构详解
1. 请求级别的追踪缓冲区
1
2
3
4
5
6
7
8
class Req:
# Forward 计算使用的 slot(始终存在)
mamba_pool_idx: int
# extra_buffer 模式新增
mamba_ping_pong_track_buffer: torch.Tensor # shape [2],存储两个 slot 索引
mamba_next_track_idx: int # 0 或 1,当前使用的 buffer 索引
mamba_last_track_seqlen: int # 上次追踪时的序列长度2. Pool 级别的映射表
1
2
3
4
5
class HybridReqToTokenPool:
# 仅在 extra_buffer 模式下分配
req_index_to_mamba_ping_pong_track_buffer_mapping: torch.Tensor
# shape: [size, ping_pong_track_buffer_size]
# ping_pong_track_buffer_size = 2 (overlap schedule) 或 1 (非 overlap)3. 关键配置参数
1
2
3
4
5
6
7
8
9
# server_args.py
mamba_scheduler_strategy: str = "auto" # "no_buffer" 或 "extra_buffer"
mamba_track_interval: int = 256 # Decode 阶段的追踪间隔
# 动态计算
@property
def mamba_cache_chunk_size(self) -> int:
# Prefill 阶段的缓存粒度
return max(FLA_CHUNK_SIZE, self.page_size) # 通常为 512 或更大状态追踪机制
Prefill 阶段:Chunk 边界追踪
当请求长度超过 mamba_cache_chunk_size 时,在每个 chunk
边界追踪状态:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def _init_mamba_tracking(req, ...):
mamba_cache_chunk_size = get_global_server_args().mamba_cache_chunk_size
mask = req.extend_input_len >= mamba_cache_chunk_size
if mask:
# 计算对齐的追踪位置
mamba_track_seqlen_aligned = (
len(req.prefix_indices)
+ (req.extend_input_len // mamba_cache_chunk_size) * mamba_cache_chunk_size
)
# 切换 ping-pong buffer
req.mamba_next_track_idx = get_mamba_ping_pong_other_idx(req.mamba_next_track_idx)
req.mamba_last_track_seqlen = mamba_track_seqlen_aligned缓存长度必须是 64 的整数倍。如果分叉位置是 300,则缓存位置会回退到
- 如果 当前路径上从未发生分叉 → prefix = 最后一个 64 对齐点
- 如果 中途发生过分叉 → prefix = 分叉点的最后一个 64 对齐点
流程示意:
假设依次来了以下请求:
- CONTENT = “1” * 256 + “2” * 256 # prefix 0
- CONTENT = “1” * 256 + “3” * 256 # prefix 0
- CONTENT = “1” * 256 + “4” * 256 # prefix 256
- CONTENT = “1” * 256 + “2” * 256 + “5” * 256 # prefix 512
- CONTENT = “1” * 256 + “4” * 256 + “5” * 256 # prefix 512
- CONTENT = “1” * 256 + “3” * 256 + “5” * 256 # prefix 256
- CONTENT = “1” * 256 + “3” * 256 # prefix 256
- CONTENT = “1” * 256 + “3” * 256 # prefix 512
graph TD
subgraph S1["R1: (1+2) -> Prefix 0"]
r1((Root)) --- n1_1["1 (256)"] --- n1_2["2 (256)"]
style n1_2 fill:#4CAF50,color:#fff
end
subgraph S2["R2: (1+3) -> Prefix 0"]
r2((Root)) --- n2_1["1 (256)"]
n2_1 --- n2_2["2 (256)"]
n2_1 --- n2_3["3 (256)"]
style n2_1 stroke-dasharray: 5 5
style n2_2 fill:#4CAF50,color:#fff
style n2_1 fill:#4CAF50,color:#fff
end
subgraph S3["R3: (1+4) -> Prefix 256"]
r3((Root)) --- n3_1["1 (256)"]
n3_1 --- n3_2["2 (256)"]
n3_1 --- n3_3["3 (256)"]
n3_1 --- n3_4["4 (256)"]
style n3_1 fill:#4CAF50,color:#fff
style n3_2 fill:#4CAF50,color:#fff
style n3_4 fill:#4CAF50,color:#fff
end
subgraph S4["R4: (1+2+5) -> Prefix 512"]
r4((Root)) --- n4_1["1 (256)"]
n4_1 --- n4_2["2 (256)"]
n4_2 --- n4_5["5 (256)"]
n4_1 --- n4_3["3 (256)"]
n4_1 --- n4_4["4 (256)"]
style n4_1 fill:#4CAF50,color:#fff
style n4_2 fill:#4CAF50,color:#fff
style n4_4 fill:#4CAF50,color:#fff
style n4_5 fill:#4CAF50,color:#fff
end
subgraph S5["R5: (1+4+5) -> Prefix 512"]
r5((Root)) --- n5_1["1 (256)"]
n5_1 --- n5_2["2 (256)"]
n5_2 --- n5_2_1["5 (256)"]
n5_1 --- n5_3["3 (256)"]
n5_1 --- n5_4["4 (256)"]
n5_4 --- n5_5["5 (256)"]
style n5_1 fill:#4CAF50,color:#fff
style n5_4 fill:#4CAF50,color:#fff
style n5_2 fill:#4CAF50,color:#fff
style n5_2_1 fill:#4CAF50,color:#fff
style n5_5 fill:#4CAF50,color:#fff
end
subgraph S6["R6: (1+3+5) -> Prefix 256"]
r6((Root)) --- n6_1["1 (256)"]
n6_1 --- n6_2["2 (256)"]
n6_2 --- n6_2_1["5 (256)"]
n6_1 --- n6_3["3 (256)"]
n6_3 --- n6_5["5 (256)"]
n6_1 --- n6_4["4 (256)"]
n6_4 --- n6_4_1["5 (256)"]
style n6_2 fill:#4CAF50,color:#fff
style n6_2_1 fill:#4CAF50,color:#fff
style n6_1 fill:#4CAF50,color:#fff
style n6_4 fill:#4CAF50,color:#fff
style n6_4_1 fill:#4CAF50,color:#fff
style n6_5 fill:#4CAF50,color:#fff
end
subgraph S7["R7: (1+3) -> Prefix 256"]
r7((Root)) --- n7_1["1 (256)"]
n7_1 --- n7_2["2 (256)"]
n7_2 --- n7_2_1["5 (256)"]
n7_1 --- n7_3["3 (256)"]
n7_3 --- n7_5["5 (256)"]
n7_1 --- n7_4["4 (256)"]
n7_4 --- n7_4_1["5 (256)"]
style n7_1 fill:#4CAF50,color:#fff
style n7_2 fill:#4CAF50,color:#fff
style n7_2_1 fill:#4CAF50,color:#fff
style n7_3 fill:#4CAF50,color:#fff
style n7_5 fill:#4CAF50,color:#fff
style n7_4 fill:#4CAF50,color:#fff
style n7_4_1 fill:#4CAF50,color:#fff
end
subgraph S8["R8: (1+3) -> Prefix 512"]
r8((Root)) --- n8_1["1 (256)"]
n8_1 --- n8_2["..."]
n8_1["1 (256)"] --- n8_3["3 (256)"]
n8_3["3 (256)"] --- n8_3_1["..."]
n8_1 --- n8_4["..."]
style n8_1 fill:#4CAF50,color:#fff
style n8_2 fill:#4CAF50,color:#fff
style n8_3 fill:#4CAF50,color:#fff
style n8_4 fill:#4CAF50,color:#fff
style n8_3_1 fill:#4CAF50,color:#fff
end
S1 --> S2 --> S3 --> S4 --> S5 --> S6 --> S7 --> S8
Decode 阶段:固定间隔追踪
在 Decode 阶段,每隔 mamba_track_interval(默认
256)追踪一次:
1
2
3
4
5
6
7
def _mamba_prefix_cache_update(req, batch, ...):
seq_len = len(req.origin_input_ids) + len(req.output_ids) - 1
if seq_len % mamba_track_interval == 0:
# 切换到另一个 buffer
req.mamba_next_track_idx = get_mamba_ping_pong_other_idx(req.mamba_next_track_idx)
req.mamba_last_track_seqlen = seq_len缓存操作详解
cache_finished_req:请求完成时的缓存
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def cache_finished_req(self, req: Req, is_insert: bool = True):
if self.enable_mamba_extra_buffer:
# 获取要保留的 buffer 索引("另一个" buffer)
mamba_ping_pong_track_buffer_to_keep = (
self.req_to_token_pool.get_mamba_ping_pong_other_idx(
req.mamba_next_track_idx
)
)
# 直接使用该 buffer 的 slot 索引,无需复制!
mamba_value = (
req.mamba_ping_pong_track_buffer[mamba_ping_pong_track_buffer_to_keep]
.unsqueeze(-1)
.clone()
)
else:
# no_buffer 模式:需要 fork
mamba_value = req.mamba_pool_idx.unsqueeze(-1).clone()
mamba_value_forked = self.mamba_pool.fork_from(mamba_value)
mamba_value = mamba_value_forked
# 插入 Radix Tree
result = self.insert(InsertParams(key=..., value=..., mamba_value=mamba_value))
# extra_buffer 模式下,释放"当前" buffer,保留"另一个"给 Radix Tree
self.req_to_token_pool.free_mamba_cache(
req,
mamba_ping_pong_track_buffer_to_keep=mamba_ping_pong_track_buffer_to_keep,
)关键区别:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
┌─────────────────────────────────────────────────────────────────────────┐
│ extra_buffer vs no_buffer: cache_finished_req │
└─────────────────────────────────────────────────────────────────────────┘
extra_buffer:
请求的 track_buffer: [slot B, slot C]
当前使用: slot C (mamba_next_track_idx = 1)
缓存时:
├─ mamba_ping_pong_track_buffer_to_keep = 0 (slot B)
├─ Radix Tree 直接使用 slot B(无需复制)
└─ 释放 slot C,请求结束
no_buffer:
请求使用: slot A (mamba_pool_idx)
缓存时:
├─ fork_from(slot A) → slot D(分配新 slot + 复制状态)
├─ Radix Tree 使用 slot D
└─ 释放 slot Acache_unfinished_req:Chunk 中间缓存
对于长请求的 Chunk 边界缓存,逻辑类似:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def cache_unfinished_req(self, req: Req, chunked=False):
if self.enable_mamba_extra_buffer:
mamba_ping_pong_track_buffer_to_keep = (
self.req_to_token_pool.get_mamba_ping_pong_other_idx(
req.mamba_next_track_idx
)
)
mamba_value = (
req.mamba_ping_pong_track_buffer[mamba_ping_pong_track_buffer_to_keep]
.unsqueeze(-1)
.clone()
)
else:
mamba_value = self.req_to_token_pool.get_mamba_indices(req.req_pool_idx)
mamba_value_forked = self.mamba_pool.fork_from(mamba_value)
mamba_value = mamba_value_forked
result = self.insert(...)
# 注意:这里不释放请求的 mamba slot!
# extra_buffer: 请求继续使用另一个 buffer
# no_buffer: 请求继续使用原始 slot完整流程图
短请求流程(长度 < mamba_cache_chunk_size)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
┌─────────────────────────────────────────────────────────────────────────┐
│ 短请求处理流程(extra_buffer 模式,无 Chunk 边界) │
└─────────────────────────────────────────────────────────────────────────┘
1. 请求初始化
│
├─ alloc() 分配:
│ ├─ mamba_pool_idx = slot A
│ └─ mamba_ping_pong_track_buffer = [slot B, slot C]
│ └─ mamba_next_track_idx = 0
│
▼
2. Forward 计算
│
├─ 使用 slot A 计算,更新 Mamba 状态
├─ 长度 < mamba_cache_chunk_size,不触发追踪
│
▼
3. cache_finished_req()
│
├─ mamba_ping_pong_track_buffer_to_keep = 1 (slot C)
│ └─ get_mamba_ping_pong_other_idx(0) = 1
│
├─ Radix Tree 接管 slot C
│ └─ insert(mamba_value=slot C)
│
├─ free_mamba_cache():
│ ├─ 释放 slot A (mamba_pool_idx)
│ └─ 释放 slot B (track_buffer[0])
│ └─ 保留 slot C (已在 Radix Tree)
│
▼
4. 后续请求匹配
│
└─ 找到 slot C → COW 复制到新请求空间长请求流程(Chunked Prefill)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
┌─────────────────────────────────────────────────────────────────────────┐
│ 长请求处理流程(extra_buffer 模式,mamba_cache_chunk_size=512) │
│ 请求长度 1200 │
└─────────────────────────────────────────────────────────────────────────┘
1. 请求初始化
│
├─ mamba_pool_idx = slot A
├─ mamba_ping_pong_track_buffer = [slot B, slot C]
└─ mamba_next_track_idx = 0
2. Chunk 1: tokens[0:512]
│
├─ Forward 计算,更新 slot A
│
├─ 追踪点 512:
│ ├─ 状态写入 track_buffer[0] (slot B)
│ └─ mamba_next_track_idx = 1
│
└─ cache_unfinished_req():
├─ mamba_ping_pong_track_buffer_to_keep = 0
├─ Radix Tree 接管 slot B
├─ 请求保留 [None, slot C](slot B 已转移)
└─ 不释放 slot A 和 slot C
3. Chunk 2: tokens[512:1024]
│
├─ Forward 计算,更新 slot A
│
├─ 追踪点 1024:
│ ├─ 状态写入 track_buffer[1] (slot C)
│ └─ mamba_next_track_idx = 0
│
└─ cache_unfinished_req():
├─ mamba_ping_pong_track_buffer_to_keep = 1
├─ Radix Tree 接管 slot C
└─ 请求保留 [None, None](两个 buffer 都已转移)
4. Chunk 3: tokens[1024:1200]
│
├─ Forward 计算,更新 slot A
├─ 没有新的追踪点
│
└─ cache_finished_req():
├─ mamba_ping_pong_track_buffer_to_keep = 1 (上次的值)
├─ 但 track_buffer[1] 已经是 None(已转移)
├─ 使用 slot A 的状态(需要 fork)
└─ 释放所有剩余 slot
5. Radix Tree 最终状态
│
├─ 节点 [0:512]: mamba_value = slot B
├─ 节点 [0:1024]: mamba_value = slot C
└─ 节点 [0:1200]: mamba_value = slot D (fork 自 slot A)实现细节:Ping-Pong Buffer 的生命周期
分配时机
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# memory_pool.py: alloc()
def alloc(self, reqs: List["Req"]):
for req in reqs:
# 总是分配 mamba_pool_idx
mid = self.mamba_pool.alloc(1)
req.mamba_pool_idx = mid[0]
# extra_buffer 模式额外分配追踪 buffer
if self.enable_mamba_extra_buffer:
if req.mamba_ping_pong_track_buffer is None:
req.mamba_ping_pong_track_buffer = self.mamba_pool.alloc(
self.mamba_ping_pong_track_buffer_size # 通常为 2
)
req.mamba_next_track_idx = 0释放逻辑
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# memory_pool.py: free_mamba_cache()
def free_mamba_cache(self, req, mamba_ping_pong_track_buffer_to_keep=None):
# 总是释放 mamba_pool_idx
self.mamba_pool.free(req.mamba_pool_idx)
req.mamba_pool_idx = None
if self.enable_mamba_extra_buffer:
buffer_to_free = self.req_index_to_mamba_ping_pong_track_buffer_mapping[req.req_pool_idx]
if mamba_ping_pong_track_buffer_to_keep is not None:
# 保留一个 buffer 给 Radix Tree,释放另一个
idx_to_free = 1 - mamba_ping_pong_track_buffer_to_keep
buffer_to_free = buffer_to_free[idx_to_free : idx_to_free + 1]
self.mamba_pool.free(buffer_to_free)总结
extra_buffer 模式通过 Ping-Pong
缓冲机制,将状态追踪与状态计算解耦:
| 设计要点 | 说明 |
|---|---|
| Ping-Pong Buffer | 两个 slot 交替使用,避免覆盖正在使用的状态 |
| 所有权转移 | 缓存时直接转移 buffer 所有权,无需复制 |
| 追踪粒度 | Prefill: mamba_cache_chunk_size,Decode:
mamba_track_interval |
| 内存-性能权衡 | 以 3 倍内存开销换取零复制性能 |
评论
匿名评论隐私政策




