mamba-radix-cache-extra-buffer

SGLang Mamba Radix Cache 深度解析:extra_buffer 模式的 Ping-Pong 缓冲机制

#14792 | e61dabf5e | 2025-12-14 | [Qwen3-next] support mamba radix cache for overlap scheduler - 引入 extra_buffer 模式,支持 overlap scheduler 和 Ping-Pong 缓冲机制 |
#15180 | 36fcf71ff | 2025-12-16 | [Qwen3-next] Add PD disaggregation support for mamba with extra_buffer - 为 extra_buffer 添加 Prefill-Decode disaggregation 支持 |

相关配置参数

参数 默认值 说明
--mamba-scheduler-strategy auto 调度策略:no_bufferextra_bufferauto(自动选择)
--mamba-track-interval 256 Decode 阶段的状态追踪间隔(token 数)
--mamba-cache-chunk-size 64 Prefill 阶段的缓存粒度,取 max(FLA_CHUNK_SIZE, page_size)
--mamba-ssm-dtype float32 SSM 状态的数据类型,也可以设置成bfloat16,float16
--mamba-full-memory-ratio 0.9 Mamba 状态内存与 KV cache 内存的比率
--max-mamba-cache-size None Mamba cache 的最大大小

关键参数详解

  • mamba_cache_chunk_size:Prefill 阶段状态追踪的粒度。当请求长度超过此值时,在每个 chunk 边界追踪状态。默认取 max(FLA_CHUNK_SIZE, page_size) = max(64, page_size),确保满足 FLA 64 位对齐要求。由于 page_size 默认为 1,实际默认值为 64。
  • mamba_track_interval:Decode 阶段每隔多少 token 追踪一次状态。较小的值提供更细粒度的前缀匹配机会,但会增加追踪开销。
  • ping_pong_track_buffer_size:每个请求额外分配的追踪 buffer 数量。Overlap scheduler 模式下为 2(真正的 Ping-Pong),非 overlap 模式为 1。

背景

在上一篇博客中,我们深入分析了 no_buffer 模式下的状态管理机制。我们看到,no_buffer 模式通过 fork_from 操作在请求完成时复制 Mamba 状态到新的 slot,确保 Radix Tree 持有独立的状态副本。

然而,这种每次缓存都需要复制状态的机制存在性能开销,还无法实现overlap调度。本篇将介绍另一种调度策略——extra_buffer,它通过 Ping-Pong 缓冲机制避免了频繁的状态复制。

核心问题:no_buffer 的性能瓶颈

no_buffer 模式下,每次缓存操作都涉及:

1
2
3
4
5
6
7
请求完成或 Chunk 边界

fork_from(src_slot) → 分配新 slot + 复制状态

插入 Radix Tree

释放原始 slot

状态复制的开销

1
2
3
4
5
Mamba 状态大小(Qwen3-Next 单层):
- conv_state: [conv_dim/tp, conv_kernel-1] ≈ 1KB
- ssm_state: [num_heads/tp, head_dim, dstate] = [32, 128, 128] ≈ 2MB

总计(假设 32 层 Mamba): ~64MB 每次 fork

当缓存命中率高时,频繁的状态复制会成为性能瓶颈。

extra_buffer 的核心思想

核心洞察:如果在请求执行过程中就预留额外的缓冲区来追踪中间状态,就可以在缓存时直接转移所有权,避免复制。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
┌─────────────────────────────────────────────────────────────────────────┐
│                    extra_buffer 的 Ping-Pong 机制                       │
└─────────────────────────────────────────────────────────────────────────┘

请求初始化时分配:
- mamba_pool_idx: slot A (用于 forward 计算)
- mamba_ping_pong_track_buffer: [slot B, slot C] (用于追踪中间状态)

Forward 计算:
  - 使用 slot A 计算当前状态
  - 在 track interval 边界,将状态追踪到 track_buffer[slot B]

缓存时:
  - Radix Tree 接管 track_buffer[slot B]
  - 请求保留 track_buffer[slot C] 用于后续追踪
  - 无需状态复制!

数据结构详解

1. 请求级别的追踪缓冲区

1
2
3
4
5
6
7
8
class Req:
    # Forward 计算使用的 slot(始终存在)
    mamba_pool_idx: int

    # extra_buffer 模式新增
    mamba_ping_pong_track_buffer: torch.Tensor  # shape [2],存储两个 slot 索引
    mamba_next_track_idx: int                    # 0 或 1,当前使用的 buffer 索引
    mamba_last_track_seqlen: int                 # 上次追踪时的序列长度

2. Pool 级别的映射表

1
2
3
4
5
class HybridReqToTokenPool:
    # 仅在 extra_buffer 模式下分配
    req_index_to_mamba_ping_pong_track_buffer_mapping: torch.Tensor
        # shape: [size, ping_pong_track_buffer_size]
        # ping_pong_track_buffer_size = 2 (overlap schedule) 或 1 (非 overlap)

3. 关键配置参数

1
2
3
4
5
6
7
8
9
# server_args.py
mamba_scheduler_strategy: str = "auto"      # "no_buffer" 或 "extra_buffer"
mamba_track_interval: int = 256             # Decode 阶段的追踪间隔

# 动态计算
@property
def mamba_cache_chunk_size(self) -> int:
    # Prefill 阶段的缓存粒度
    return max(FLA_CHUNK_SIZE, self.page_size)  # 通常为 512 或更大

状态追踪机制

Prefill 阶段:Chunk 边界追踪

当请求长度超过 mamba_cache_chunk_size 时,在每个 chunk 边界追踪状态:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def _init_mamba_tracking(req, ...):
    mamba_cache_chunk_size = get_global_server_args().mamba_cache_chunk_size
    mask = req.extend_input_len >= mamba_cache_chunk_size

    if mask:
        # 计算对齐的追踪位置
        mamba_track_seqlen_aligned = (
            len(req.prefix_indices)
            + (req.extend_input_len // mamba_cache_chunk_size) * mamba_cache_chunk_size
        )

        # 切换 ping-pong buffer
        req.mamba_next_track_idx = get_mamba_ping_pong_other_idx(req.mamba_next_track_idx)
        req.mamba_last_track_seqlen = mamba_track_seqlen_aligned

缓存长度必须是 64 的整数倍。如果分叉位置是 300,则缓存位置会回退到

  • 如果 当前路径上从未发生分叉 → prefix = 最后一个 64 对齐点
  • 如果 中途发生过分叉 → prefix = 分叉点的最后一个 64 对齐点

流程示意

假设依次来了以下请求:

  1. CONTENT = “1” * 256 + “2” * 256 # prefix 0
  2. CONTENT = “1” * 256 + “3” * 256 # prefix 0
  3. CONTENT = “1” * 256 + “4” * 256 # prefix 256
  4. CONTENT = “1” * 256 + “2” * 256 + “5” * 256 # prefix 512
  5. CONTENT = “1” * 256 + “4” * 256 + “5” * 256 # prefix 512
  6. CONTENT = “1” * 256 + “3” * 256 + “5” * 256 # prefix 256
  7. CONTENT = “1” * 256 + “3” * 256 # prefix 256
  8. CONTENT = “1” * 256 + “3” * 256 # prefix 512
graph TD
    subgraph S1["R1: (1+2) -> Prefix 0"]
        r1((Root)) --- n1_1["1 (256)"] --- n1_2["2 (256)"]
        style n1_2 fill:#4CAF50,color:#fff
    end

    subgraph S2["R2: (1+3) -> Prefix 0"]
        r2((Root)) --- n2_1["1 (256)"]
        n2_1 --- n2_2["2 (256)"]
        n2_1 --- n2_3["3 (256)"]
        style n2_1 stroke-dasharray: 5 5
        style n2_2 fill:#4CAF50,color:#fff
        style n2_1 fill:#4CAF50,color:#fff
    end

    subgraph S3["R3: (1+4) -> Prefix 256"]
        r3((Root)) --- n3_1["1 (256)"]
        n3_1 --- n3_2["2 (256)"]
        n3_1 --- n3_3["3 (256)"]
        n3_1 --- n3_4["4 (256)"]
        style n3_1 fill:#4CAF50,color:#fff
        style n3_2 fill:#4CAF50,color:#fff
        style n3_4 fill:#4CAF50,color:#fff
    end

    subgraph S4["R4: (1+2+5) -> Prefix 512"]
        r4((Root)) --- n4_1["1 (256)"]
        n4_1 --- n4_2["2 (256)"]
        n4_2 --- n4_5["5 (256)"]
        n4_1 --- n4_3["3 (256)"]
        n4_1 --- n4_4["4 (256)"]
        style n4_1 fill:#4CAF50,color:#fff
        style n4_2 fill:#4CAF50,color:#fff
        style n4_4 fill:#4CAF50,color:#fff
        style n4_5 fill:#4CAF50,color:#fff
    end

    subgraph S5["R5: (1+4+5) -> Prefix 512"]
        r5((Root)) --- n5_1["1 (256)"]
        n5_1 --- n5_2["2 (256)"]
        n5_2 --- n5_2_1["5 (256)"]
        n5_1 --- n5_3["3 (256)"]
        n5_1 --- n5_4["4 (256)"]
        n5_4 --- n5_5["5 (256)"]
        style n5_1 fill:#4CAF50,color:#fff
        style n5_4 fill:#4CAF50,color:#fff
        style n5_2 fill:#4CAF50,color:#fff
        style n5_2_1 fill:#4CAF50,color:#fff
        style n5_5 fill:#4CAF50,color:#fff
    end

    subgraph S6["R6: (1+3+5) -> Prefix 256"]
        r6((Root)) --- n6_1["1 (256)"]
        n6_1 --- n6_2["2 (256)"]
        n6_2 --- n6_2_1["5 (256)"]
        n6_1 --- n6_3["3 (256)"]
        n6_3 --- n6_5["5 (256)"]
        n6_1 --- n6_4["4 (256)"]
        n6_4 --- n6_4_1["5 (256)"]
        style n6_2 fill:#4CAF50,color:#fff
        style n6_2_1 fill:#4CAF50,color:#fff
        style n6_1 fill:#4CAF50,color:#fff
        style n6_4 fill:#4CAF50,color:#fff
        style n6_4_1 fill:#4CAF50,color:#fff
        style n6_5 fill:#4CAF50,color:#fff
    end

    subgraph S7["R7: (1+3) -> Prefix 256"]
        r7((Root)) --- n7_1["1 (256)"]
        n7_1 --- n7_2["2 (256)"]
        n7_2 --- n7_2_1["5 (256)"]
        n7_1 --- n7_3["3 (256)"]
        n7_3 --- n7_5["5 (256)"]
        n7_1 --- n7_4["4 (256)"]
        n7_4 --- n7_4_1["5 (256)"]
        style n7_1 fill:#4CAF50,color:#fff
        style n7_2 fill:#4CAF50,color:#fff
        style n7_2_1 fill:#4CAF50,color:#fff
        style n7_3 fill:#4CAF50,color:#fff
        style n7_5 fill:#4CAF50,color:#fff
        style n7_4 fill:#4CAF50,color:#fff
        style n7_4_1 fill:#4CAF50,color:#fff
    end

    subgraph S8["R8: (1+3) -> Prefix 512"]
        r8((Root)) --- n8_1["1 (256)"]
        n8_1 --- n8_2["..."]
        n8_1["1 (256)"] --- n8_3["3 (256)"]
        n8_3["3 (256)"] --- n8_3_1["..."]
        n8_1 --- n8_4["..."]
        style n8_1 fill:#4CAF50,color:#fff
        style n8_2 fill:#4CAF50,color:#fff
        style n8_3 fill:#4CAF50,color:#fff
        style n8_4 fill:#4CAF50,color:#fff
        style n8_3_1 fill:#4CAF50,color:#fff
    end

    S1 --> S2 --> S3 --> S4 --> S5 --> S6 --> S7 --> S8

Decode 阶段:固定间隔追踪

在 Decode 阶段,每隔 mamba_track_interval(默认 256)追踪一次:

1
2
3
4
5
6
7
def _mamba_prefix_cache_update(req, batch, ...):
    seq_len = len(req.origin_input_ids) + len(req.output_ids) - 1

    if seq_len % mamba_track_interval == 0:
        # 切换到另一个 buffer
        req.mamba_next_track_idx = get_mamba_ping_pong_other_idx(req.mamba_next_track_idx)
        req.mamba_last_track_seqlen = seq_len

缓存操作详解

cache_finished_req:请求完成时的缓存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def cache_finished_req(self, req: Req, is_insert: bool = True):
    if self.enable_mamba_extra_buffer:
        # 获取要保留的 buffer 索引("另一个" buffer)
        mamba_ping_pong_track_buffer_to_keep = (
            self.req_to_token_pool.get_mamba_ping_pong_other_idx(
                req.mamba_next_track_idx
            )
        )
        # 直接使用该 buffer 的 slot 索引,无需复制!
        mamba_value = (
            req.mamba_ping_pong_track_buffer[mamba_ping_pong_track_buffer_to_keep]
            .unsqueeze(-1)
            .clone()
        )
    else:
        # no_buffer 模式:需要 fork
        mamba_value = req.mamba_pool_idx.unsqueeze(-1).clone()
        mamba_value_forked = self.mamba_pool.fork_from(mamba_value)
        mamba_value = mamba_value_forked

    # 插入 Radix Tree
    result = self.insert(InsertParams(key=..., value=..., mamba_value=mamba_value))

    # extra_buffer 模式下,释放"当前" buffer,保留"另一个"给 Radix Tree
    self.req_to_token_pool.free_mamba_cache(
        req,
        mamba_ping_pong_track_buffer_to_keep=mamba_ping_pong_track_buffer_to_keep,
    )

关键区别

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
┌─────────────────────────────────────────────────────────────────────────┐
│                 extra_buffer vs no_buffer: cache_finished_req           │
└─────────────────────────────────────────────────────────────────────────┘

extra_buffer:
    请求的 track_buffer: [slot B, slot C]
    当前使用: slot C (mamba_next_track_idx = 1)

    缓存时:
    ├─ mamba_ping_pong_track_buffer_to_keep = 0 (slot B)
    ├─ Radix Tree 直接使用 slot B(无需复制)
    └─ 释放 slot C,请求结束

no_buffer:
    请求使用: slot A (mamba_pool_idx)

    缓存时:
    ├─ fork_from(slot A) → slot D(分配新 slot + 复制状态)
    ├─ Radix Tree 使用 slot D
    └─ 释放 slot A

cache_unfinished_req:Chunk 中间缓存

对于长请求的 Chunk 边界缓存,逻辑类似:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def cache_unfinished_req(self, req: Req, chunked=False):
    if self.enable_mamba_extra_buffer:
        mamba_ping_pong_track_buffer_to_keep = (
            self.req_to_token_pool.get_mamba_ping_pong_other_idx(
                req.mamba_next_track_idx
            )
        )
        mamba_value = (
            req.mamba_ping_pong_track_buffer[mamba_ping_pong_track_buffer_to_keep]
            .unsqueeze(-1)
            .clone()
        )
    else:
        mamba_value = self.req_to_token_pool.get_mamba_indices(req.req_pool_idx)
        mamba_value_forked = self.mamba_pool.fork_from(mamba_value)
        mamba_value = mamba_value_forked

    result = self.insert(...)

    # 注意:这里不释放请求的 mamba slot!
    # extra_buffer: 请求继续使用另一个 buffer
    # no_buffer: 请求继续使用原始 slot

完整流程图

短请求流程(长度 < mamba_cache_chunk_size)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
┌─────────────────────────────────────────────────────────────────────────┐
│          短请求处理流程(extra_buffer 模式,无 Chunk 边界)              │
└─────────────────────────────────────────────────────────────────────────┘

1. 请求初始化

   ├─ alloc() 分配:
   │   ├─ mamba_pool_idx = slot A
   │   └─ mamba_ping_pong_track_buffer = [slot B, slot C]
   │   └─ mamba_next_track_idx = 0


2. Forward 计算

   ├─ 使用 slot A 计算,更新 Mamba 状态
   ├─ 长度 < mamba_cache_chunk_size,不触发追踪


3. cache_finished_req()

   ├─ mamba_ping_pong_track_buffer_to_keep = 1 (slot C)
   │   └─ get_mamba_ping_pong_other_idx(0) = 1

   ├─ Radix Tree 接管 slot C
   │   └─ insert(mamba_value=slot C)

   ├─ free_mamba_cache():
   │   ├─ 释放 slot A (mamba_pool_idx)
   │   └─ 释放 slot B (track_buffer[0])
   │   └─ 保留 slot C (已在 Radix Tree)


4. 后续请求匹配

   └─ 找到 slot C → COW 复制到新请求空间

长请求流程(Chunked Prefill)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
┌─────────────────────────────────────────────────────────────────────────┐
│       长请求处理流程(extra_buffer 模式,mamba_cache_chunk_size=512)   │
│       请求长度 1200                                                      │
└─────────────────────────────────────────────────────────────────────────┘

1. 请求初始化

   ├─ mamba_pool_idx = slot A
   ├─ mamba_ping_pong_track_buffer = [slot B, slot C]
   └─ mamba_next_track_idx = 0

2. Chunk 1: tokens[0:512]

   ├─ Forward 计算,更新 slot A

   ├─ 追踪点 512:
   │   ├─ 状态写入 track_buffer[0] (slot B)
   │   └─ mamba_next_track_idx = 1

   └─ cache_unfinished_req():
       ├─ mamba_ping_pong_track_buffer_to_keep = 0
       ├─ Radix Tree 接管 slot B
       ├─ 请求保留 [None, slot C](slot B 已转移)
       └─ 不释放 slot A 和 slot C

3. Chunk 2: tokens[512:1024]

   ├─ Forward 计算,更新 slot A

   ├─ 追踪点 1024:
   │   ├─ 状态写入 track_buffer[1] (slot C)
   │   └─ mamba_next_track_idx = 0

   └─ cache_unfinished_req():
       ├─ mamba_ping_pong_track_buffer_to_keep = 1
       ├─ Radix Tree 接管 slot C
       └─ 请求保留 [None, None](两个 buffer 都已转移)

4. Chunk 3: tokens[1024:1200]

   ├─ Forward 计算,更新 slot A
   ├─ 没有新的追踪点

   └─ cache_finished_req():
       ├─ mamba_ping_pong_track_buffer_to_keep = 1 (上次的值)
       ├─ 但 track_buffer[1] 已经是 None(已转移)
       ├─ 使用 slot A 的状态(需要 fork)
       └─ 释放所有剩余 slot

5. Radix Tree 最终状态

   ├─ 节点 [0:512]: mamba_value = slot B
   ├─ 节点 [0:1024]: mamba_value = slot C
   └─ 节点 [0:1200]: mamba_value = slot D (fork 自 slot A)

实现细节:Ping-Pong Buffer 的生命周期

分配时机

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# memory_pool.py: alloc()
def alloc(self, reqs: List["Req"]):
    for req in reqs:
        # 总是分配 mamba_pool_idx
        mid = self.mamba_pool.alloc(1)
        req.mamba_pool_idx = mid[0]

        # extra_buffer 模式额外分配追踪 buffer
        if self.enable_mamba_extra_buffer:
            if req.mamba_ping_pong_track_buffer is None:
                req.mamba_ping_pong_track_buffer = self.mamba_pool.alloc(
                    self.mamba_ping_pong_track_buffer_size  # 通常为 2
                )
                req.mamba_next_track_idx = 0

释放逻辑

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# memory_pool.py: free_mamba_cache()
def free_mamba_cache(self, req, mamba_ping_pong_track_buffer_to_keep=None):
    # 总是释放 mamba_pool_idx
    self.mamba_pool.free(req.mamba_pool_idx)
    req.mamba_pool_idx = None

    if self.enable_mamba_extra_buffer:
        buffer_to_free = self.req_index_to_mamba_ping_pong_track_buffer_mapping[req.req_pool_idx]

        if mamba_ping_pong_track_buffer_to_keep is not None:
            # 保留一个 buffer 给 Radix Tree,释放另一个
            idx_to_free = 1 - mamba_ping_pong_track_buffer_to_keep
            buffer_to_free = buffer_to_free[idx_to_free : idx_to_free + 1]

        self.mamba_pool.free(buffer_to_free)

总结

extra_buffer 模式通过 Ping-Pong 缓冲机制,将状态追踪与状态计算解耦:

设计要点 说明
Ping-Pong Buffer 两个 slot 交替使用,避免覆盖正在使用的状态
所有权转移 缓存时直接转移 buffer 所有权,无需复制
追踪粒度 Prefill: mamba_cache_chunk_size,Decode: mamba_track_interval
内存-性能权衡 以 3 倍内存开销换取零复制性能