mamba-radix-cache-no-buffer

SGLang Mamba Radix Cache 深度解析:no_buffer 模式下的状态管理

no_buffer第一次引入来源于commit a55cf5304 #11214

在这个commit中贴出了对应的实现文档:

https://docs.google.com/document/d/1ZYBsZaz58NIIFmlOuQXqKPruIjv1ZJIUUrIj-sZlT4U/edit?tab=t.0

说明文档

V0 uses LRU list for mamba cache/full attention for eviction like SWARadixCache. It currently only inserts request mamba cache into radix tree when do cache_finished_req / cache_unfinished_req. It can be modified to Marconi or k steps later

cache_finished_req / cache_unfinished_req:

  1. copy a request mamba state, and insert it into mamba radix tree (since mamba state is update in place)
  2. if mamba value has already existed, release this copied mamba value

insert:

  1. the same as radix tree, update full attention LRU value from the path.
  2. when split node, new_node should have be inserted into mamba LRU and full attention LRU
  3. when inserted place is a node with mamba tombstone, insert it into mamba LRU
  4. otherwise, mamba value has already existed.
  • Tombstone:指 mamba_value=None 的节点(只有 KV cache,没有 mamba状态)

match:

  1. return the longest node from root node which both mamba value is not None
  2. reset full attention LRU from this node to root path
  3. reset mamba LRU only for selected node
  4. when a request needs matched node to do prefill, it needs copy this mamba state from tree to its mamba state.

evict:

  1. be used to evict full attention and evict mamba as well. only evict from leaves.
  2. remove leaf node from full attention and mamba LRU
  3. iteratively remove parent node until its mamba value is not None

evict_mamba:

  1. be used to evict mamba, can evict from any node
  2. for parent node, remove this node from LRU and release it from mamba pool
  3. for leaf node: the same as evict logic

inc_lock_ref:

  1. for full attention, lock from node to root
  2. for mamba, only lock itself if it has mamba value

dec_lock_ref:

  1. for full attention, unlock from node to root
  2. for mamba, only unlock itself if it has mamba value

背景

Qwen3-Next 等混合架构模型同时包含 Linear Attention (Mamba) 和 Softmax Attention。对于 Linear Attention 的前缀匹配,需要缓存 SSM 状态和 Conv 状态。

本文深入解析 no_buffer 模式下的缓存分配、索引、更新和生命周期管理机制。

完整流程图

短请求流程(长度 ≤ chunked_prefill_size)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
┌─────────────────────────────────────────────────────────────────────────┐
│                    短请求处理流程(无 chunk)                           │
└─────────────────────────────────────────────────────────────────────────┘

1. 请求到达

   ├─ alloc() → 分配 mamba_pool_idx (假设为 slot 5)
   │   └─ MambaPool.alloc(1) → free_slots 取出索引,清零状态


2. 前缀匹配 match_prefix()

   ├─ 遍历 Radix Tree,找到最长匹配节点
   │   └─ 节点存储: mamba_value = slot_index (如 slot 3)

   ├─ COW: 复制匹配状态到请求空间
   │   └─ copy_from(src=3, dst=5)  # 复制 slot 3 → slot 5


3. Forward 计算

   ├─ Conv: causal_conv1d_fn(update=True)
   │   └─ conv_state[layer, 5] 被更新

   ├─ SSM: mamba_chunk_scan_combined()
   │   └─ ssm_state[layer, 5] 被更新


4. 请求完成 cache_finished_req()

   ├─ fork_from(5) → 分配新 slot (假设为 slot 8)
   │   └─ copy_from(5 → 8)  # 复制最终状态

   ├─ insert() → 将 slot 8 插入 Radix Tree
   │   └─ TreeNode.mamba_value = tensor([8])

   ├─ free_mamba_cache(req) → 释放 slot 5
   │   └─ free_slots.append(5)


5. 后续请求匹配

   └─ match_prefix() → 找到 slot 8 → COW 到新请求空间

长请求流程(长度 > chunked_prefill_size)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
┌─────────────────────────────────────────────────────────────────────────┐
│               长请求处理流程(Chunked Prefill)                         │
│               请求长度 10k,chunk_size = 8192                           │
└─────────────────────────────────────────────────────────────────────────┘

1. 请求到达

   ├─ alloc() → 分配 mamba_pool_idx = slot 5


2. 前缀匹配 match_prefix()

   └─ 假设找到 slot 3,COW: copy_from(3 → 5)


3. Chunk 1: tokens[0:8192]

   ├─ Forward 计算,更新 slot 5 的状态

   ├─ cache_unfinished_req()
   │   │
   │   ├─ fork_from(5) → slot 8
   │   ├─ insert(slot 8) → Radix Tree 在 8192 边界点
   │   │   └─ TreeNode(key=[0:8192], mamba_value=tensor([8]))
   │   │
   │   └─ 不释放 slot 5!请求继续使用


4. Chunk 2: tokens[8192:10000]

   ├─ 使用 slot 5 继续计算(状态从 chunk 1 结束处继续)

   ├─ Forward 计算,更新 slot 5 的状态

   └─ cache_finished_req()

       ├─ fork_from(5) → slot 9
       ├─ insert(slot 9) → Radix Tree 完整序列
       │   └─ TreeNode(key=[0:10000], mamba_value=tensor([9]))

       └─ free_mamba_cache(req) → 释放 slot 5


5. Radix Tree 最终状态

   ├─ 节点 [0:8192]: mamba_value = slot 8 (chunk 边界)
   └─ 节点 [0:10000]: mamba_value = slot 9 (完整序列)

关键点:Chunk 边界的中间状态(slot 8)允许后续请求在 chunk 边界处进行前缀匹配,即使原始请求尚未完全结束。

核心数据结构

1. MambaPool:状态存储的核心

Mamba 状态存储在 MambaPool 中(memory_pool.py:188-386),包含两种状态:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class MambaPool:
    def __init__(self, ...):
        # Conv State: 每层一个 tensor
        # qwen3-next是分别对qkv做卷积,这里把他们打包在conv_dim一起存储了,所以没有体现出来
        # 形状: [num_mamba_layers, size+1, conv_dim/tp, conv_kernel-1]
        conv_state = [
            torch.zeros(
                size=(num_mamba_layers, size + 1) + conv_shape,
                dtype=conv_dtype,  # 默认 bfloat16
                device=device,
            )
            for conv_shape in conv_state_shape  # 通常是 [(conv_dim/tp, conv_kernel-1)]
        ]

        # SSM State (Temporal State): 单个 tensor
        # 形状: [num_mamba_layers, size+1, num_heads/tp, head_dim, state_size]
        temporal_state = torch.zeros(
            size=(num_mamba_layers, size + 1) + temporal_state_shape,
            dtype=ssm_dtype,  # 默认 float32
            device=device,
        )
1
2
3
4
5
6
7
8
9
10
11
┌─────────────────────────────────────────────────────────────────────────┐
│                            MambaPool                                    │
├─────────────────────────────────────────────────────────────────────────┤
│  conv_state[i]: [num_layers, size+1, conv_dim/tp, conv_kernel-1]        │
│                 ↓ Conv 状态,每层独立,存储卷积核的历史输入             │
│                                                                         │
│  temporal_state: [num_layers, size+1, num_heads/tp, head_dim, dstate]   │
│                   ↓ SSM 状态,每层独立,存储隐状态 h                    │
│                                                                         │
│  free_slots: 可用的 slot 索引列表                                       │
└─────────────────────────────────────────────────────────────────────────┘

以 Qwen3-Next 为例

  • conv_state_shape = [(conv_dim/tp, 3)] — 卷积核大小通常为 4,历史长度为 3
  • temporal_state_shape = (num_heads/tp, 128, 128) — 32 头,head_dim=128,state_size=128

2. 索引管理:请求到状态的映射

1
2
3
4
HybridReqToTokenPool
├── req_to_token: [size, max_context_len] → KV cache indices
├── req_index_to_mamba_index_mapping: [size] → MambaPool slot indices
└── mamba_pool: MambaPool 实例

每个请求通过 req.mamba_pool_idx 持有一个 Mamba slot index,用于访问该请求的 Conv 和 SSM 状态。

3. Radix Tree 节点结构

1
2
3
4
class TreeNode:
    key: RadixKey              # token 序列
    value: torch.Tensor        # KV cache indices (每个 token 对应的索引)
    mamba_value: torch.Tensor  # Mamba slot index (单个索引,指向 MambaPool)

mamba_value 是一个 单索引,因为 Mamba 状态是序列级别的,不像 KV cache 是 token 级别的。

缓存分配流程

请求初始化

当新请求到达时,HybridReqToTokenPool.alloc() 分配 Mamba slot:

1
2
3
4
5
6
7
8
9
10
11
12
def alloc(self, reqs: List["Req"]):
    for req in reqs:
        if req.mamba_pool_idx is not None:
            # 已有索引(来自 radix cache 前缀匹配)
            mid = req.mamba_pool_idx
        else:
            # 分配新 slot
            mid = self.mamba_pool.alloc(1)
            req.mamba_pool_idx = mid[0]

    # 建立映射
    self.req_index_to_mamba_index_mapping[select_index] = mamba_indices

MambaPool.alloc() 的实现:

1
2
3
4
5
6
7
8
9
10
def alloc(self, need_size: int):
    select_index = self.free_slots[:need_size]
    self.free_slots = self.free_slots[need_size:]

    # 分配时清零,避免脏数据
    for conv_tensor in self.mamba_cache.conv:
        conv_tensor[:, select_index] = 0
    self.mamba_cache.temporal[:, select_index] = 0

    return select_index

Forward 中的状态更新

MambaMixer2.forward() 中,状态通过索引被访问和更新:

Prefill 阶段

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 获取状态索引
state_indices_tensor = metadata.mamba_cache_indices  # [num_reqs]
conv_state = layer_cache.conv[0]
ssm_state = layer_cache.temporal

# Conv 计算,同时更新 conv_state
hidden_states_B_C = causal_conv1d_fn(
    x, conv_weights, bias,
    conv_states=conv_state,              # 被更新
    cache_indices=state_indices_tensor,  # 每个请求的 slot index
    query_start_loc=query_start_loc,
)

# SSM 计算,返回 final state
varlen_state = mamba_chunk_scan_combined(
    hidden_states, dt, A, B, C, ...
    initial_states=initial_states,  # 来自前缀匹配
    return_varlen_states=True,
)

# 更新 ssm_state
ssm_state[state_indices_tensor] = varlen_state

Decode 阶段

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Conv 增量更新(单 token)
hidden_states_B_C = causal_conv1d_update(
    hidden_states_B_C,
    conv_state,  # in-place 更新
    conv_weights, bias,
    conv_state_indices=state_indices_tensor,
)

# SSM 增量更新(单 token)
selective_state_update(
    ssm_state,  # in-place 更新
    hidden_states, dt, A, B, C, D,
    state_batch_indices=state_indices_tensor,
    out=output,
)

Radix Tree 缓存管理

前缀匹配

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def _match_prefix_helper(self, key: RadixKey):
    while len(key) > 0 and child_key in node.children.keys():
        child = node.children[child_key]

        # 如果当前节点有 mamba_value,更新最佳匹配
        if node.mamba_value is not None:
            best_value_len = len(value)
            best_last_node = node

        # 处理 split 情况(部分匹配)
        if prefix_len < len(child.key):
            new_node = self._split_node(child.key, child, prefix_len)
            if child.mamba_value is not None:
                best_value_len = len(value)
                best_last_node = child
            break

        value.append(child.value)
        node = child

匹配后状态复制(Copy-On-Write)

1
2
3
4
5
6
7
8
if cow_mamba and last_node.mamba_value is not None:
    if req.mamba_pool_idx is None:
        dst_index = self.mamba_pool.alloc(1)
    else:
        dst_index = req.mamba_pool_idx.unsqueeze(0)

    src_index = last_node.mamba_value  # radix tree 中的 slot
    self.mamba_pool.copy_from(src_index, dst_index)  # 复制状态

请求完成后缓存(cache_finished_req)

关键问题:为什么需要 fork_from 而不是直接转移所有权?

1
2
3
4
5
6
7
8
9
10
def cache_finished_req(self, req: Req, is_insert: bool = True):
    # Fork: 分配新 slot 并复制状态
    mamba_value = req.mamba_pool_idx.unsqueeze(-1).clone()
    mamba_value_forked = self.mamba_pool.fork_from(mamba_value)

    # 插入 radix tree
    result = self.insert(..., mamba_value=mamba_value_forked)

    # 释放请求的原始 slot
    self.req_to_token_pool.free_mamba_cache(req)

Chunk 中间缓存(cache_unfinished_req)

当请求长度超过 chunked_prefill_size(默认 8192)时,请求会被分成多个 chunk 处理。每个 chunk 完成后,会调用 cache_unfinished_req 将当前状态缓存到 Radix Tree。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def cache_unfinished_req(self, req: Req, chunked=False):
    token_ids = req.fill_ids
    cache_len = len(token_ids)  # no_buffer 模式下缓存整个 chunk

    # 获取当前请求的 mamba slot index
    mamba_value = self.req_to_token_pool.get_mamba_indices(
        req.req_pool_idx
    ).unsqueeze(-1)

    # Fork 到新 slot(与 cache_finished_req 相同)
    mamba_value_forked = self.mamba_pool.fork_from(mamba_value)

    # 插入 radix tree
    result = self.insert(
        InsertParams(
            key=RadixKey(page_aligned_token_ids, req.extra_key),
            value=page_aligned_kv_indices,
            mamba_value=mamba_value_forked,
        )
    )

    # 关键区别:不释放原始 mamba slot!
    # 请求还在进行中,需要继续使用

    # 更新请求的 last_node 用于后续 chunk
    req.last_node = new_last_node

cache_unfinished_reqcache_finished_req 的关键区别

操作 cache_unfinished_req cache_finished_req
Fork 状态
插入 Radix Tree
释放原始 slot ✗(请求继续) ✓(请求结束)
更新 req.last_node ✓(后续 chunk 使用)

为什么 Chunk 模式下需要 cache_unfinished_req

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
请求长度 10k > chunked_prefill_size (8192)

Chunk 1: tokens[0:8192]

    ├─ Forward 计算,更新 slot 5 的状态

    └─ cache_unfinished_req()
        ├─ fork_from(5) → slot 8
        ├─ insert(slot 8) → Radix Tree 在 8192 边界点有状态
        └─ 不释放 slot 5,请求继续

Chunk 2: tokens[8192:10000]

    ├─ 使用 slot 5 继续计算(从 chunk 1 的最终状态开始)

    └─ cache_finished_req()
        ├─ fork_from(5) → slot 9
        ├─ insert(slot 9) → Radix Tree 在完整序列有状态
        └─ 释放 slot 5

这样设计的好处是:后续请求如果前缀在 chunk 边界(如 8192),可以直接匹配到 Radix Tree 中的状态,而不需要等待原始请求完全结束。

关键数据维度总结

组件 形状 说明
MambaPool.conv[i] [L, size+1, conv_dim/tp, K-1] Conv 状态,每层独立
MambaPool.temporal [L, size+1, H/tp, head_dim, dstate] SSM 状态,每层独立
req.mamba_pool_idx 标量 int 请求的 mamba slot index
TreeNode.mamba_value [1] tensor Radix Tree 节点的 slot index
state_indices_tensor [num_reqs] 批量状态索引(用于 forward)

总结

SGLang 的 Mamba 缓存机制通过以下设计实现了高效的状态管理:

  1. Pool + Index 架构:集中存储状态,通过索引间接访问,支持高效的状态复制和转移
  2. Radix Tree 前缀匹配:利用 token 序列的公共前缀,减少重复计算
  3. Copy-On-Write:前缀匹配时复制状态,保证不同请求的状态独立
  4. Fork 机制:请求完成时 fork 状态到新 slot,确保 Radix Tree 拥有独立副本,生命周期独立管理

这套机制在 no_buffer 模式下尤为关键,因为没有额外的 ping-pong buffer 来追踪状态变化,必须正确管理每个 slot 的所有权和生命周期。

当前存在的bug

WARNING

注意,当前的代码实现,nobuffer是有bug的,对于请求长度小于chunked prefill,并不能正确缓存下来 假如chunk size是8k,第一次输入1k数据,完成后再次请求相同请求,cached_tokens是0 而如果请求是9k,第二次时cached_tokens是8k,也就是会在chunked prefill的边界缓存下来 细节来说就是cache_finished_req有bug,而cache_unfinished_req是正确的

另外还有一个问题:no_buffer模式没有处理FLA的64位对齐要求,即使修复了上述bug, 缓存的SSM状态位置可能不是64的倍数,导致状态恢复时可能出错。

然后对于decode来说,在nobuffer下并没有实现缓存前缀,在extra-buffer实现了decode间隔缓存

测试环境

启动服务:

1
2
3
4
python3 -m sglang.launch_server --host 0.0.0.0 --port 8055 \
    --model-path /path/to/Qwen3-Next-80B-A3B-Instruct-4layers \
    --mamba-scheduler-strategy no_buffer \
    --enable-metrics --enable-cache-report

现象描述

串行发送两条相同的请求:

请求长度 第一次请求 cached_tokens 第二次请求 cached_tokens
1k 0 0 (期望有命中)
9k 0 8192 (chunked_prefill_size)

关键观察

  • 请求长度 < chunked_prefill_size (默认 8192):没有前缀匹配
  • 请求长度 > chunked_prefill_size:可以匹配到 chunked_prefill_size 边界

Bug 根因分析

问题一:cache_finished_req 没有正确复制 Mamba 状态

源码位置: python/sglang/srt/mem_cache/mamba_radix_cache.py

原始代码 (cache_finished_req 第 549-560 行):

1
2
3
4
5
6
7
8
9
10
11
else:
    mamba_value = req.mamba_pool_idx.unsqueeze(-1).clone()
    mamba_ping_pong_track_buffer_to_keep = None

result = self.insert(
    InsertParams(
        key=RadixKey(token_ids[:page_aligned_len], req.extra_key),
        value=page_aligned_kv_indices,
        mamba_value=mamba_value,
    )
)

问题分析

  1. 只复制了 mamba_pool_idx 的值(tensor),没有复制实际的 mamba 状态数据
  2. 插入 radix tree 后,free_mamba_cache 会释放原始的 mamba slot
  3. Radix tree 中存储的 index 指向已被释放的 slot,数据无效

对比 cache_unfinished_req (chunked prefill 场景):

cache_unfinished_req 正确使用了 fork_from 来复制 mamba 状态。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
else:
    mamba_value = self.req_to_token_pool.get_mamba_indices(
        req.req_pool_idx
    ).unsqueeze(-1)
# 关键:fork_from 会复制状态到新 slot
mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(mamba_value)

result = self.insert(
    InsertParams(
        key=RadixKey(page_aligned_token_ids, req.extra_key),
        value=page_aligned_kv_indices,
        mamba_value=mamba_value_forked,  # 新 slot 的 index
    )
)

问题二:_match_prefix_helper 在 split 后丢失 mamba 状态引用

源码位置: python/sglang/srt/mem_cache/mamba_radix_cache.py 第 953-993 行

场景分析

当第二个请求匹配时,假设请求长度为 1007(减1用于 logprob),而 radix tree 中节点 key 长度为 1008:

1
2
3
4
5
6
7
8
9
10
11
12
13
while len(key) > 0 and child_key in node.children.keys():
    child = node.children[child_key]
    # 更新 best_value_len
    if node.mamba_value is not None:
        best_value_len = len(value)
        best_last_node = node

    prefix_len = self.key_match_fn(child.key, key)
    if prefix_len < len(child.key):  # 1007 < 1008,触发 split
        new_node = self._split_node(child.key, child, prefix_len)
        value.append(new_node.value)
        node = new_node
        break  # 循环结束,但没有更新 best_value_len!

问题分析

  1. prefix_len=1007 < child.key_len=1008,触发 _split_node
  2. _split_node 创建新节点作为父节点,mamba_value=None
  3. 原 child 变为新节点的子节点,保留 mamba_value=5
  4. 循环 break,但 best_value_len 没有被更新
  5. 最终返回 best_value_len=0,匹配失败

调试日志验证

1
2
[DEBUG] _match_prefix_helper: matched child 7, prefix_len=1007, child.key_len=1008
[DEBUG] _match_prefix_helper: result value_len=1, best_value_len=0, best_last_node=0

问题三:no_buffer 模式未处理 FLA 64 位对齐要求

背景:FLA (Flash Linear Attention) 库要求 SSM 状态以 64 (FLA_CHUNK_SIZE) 为倍数进行缓存。

源码位置: python/sglang/srt/layers/attention/fla/chunk_delta_h.py:19

1
CHUNK_SIZE = 64

问题分析

cache_finished_req 中,no_buffer 模式直接使用完整长度缓存,没有做 FLA 对齐:

1
2
3
4
5
6
# cache_finished_req 第 508-512 行
cache_len = (
    req.mamba_last_track_seqlen  # ✅ extra_buffer: 对齐后的长度
    if self.enable_mamba_extra_buffer
    else len(token_ids)          # ❌ no_buffer: 完整长度,未对齐!
)

extra_buffer 模式的正确处理

1
2
3
4
5
6
7
8
9
# schedule_batch.py - extra_buffer 模式的对齐计算
mamba_cache_chunk_size = get_global_server_args().mamba_cache_chunk_size  # 通常是 64

mamba_track_seqlen_aligned = (
    len(req.prefix_indices)
    + (req.extend_input_len // mamba_cache_chunk_size)
    * mamba_cache_chunk_size  # 向下对齐到 64 的倍数
)
req.mamba_last_track_seqlen = mamba_track_seqlen_aligned

问题场景

1
2
3
4
5
6
7
8
9
10
请求长度 = 100 (不是 64 的倍数)

extra_buffer 模式:
  cache_len = (100 // 64) * 64 = 64 (向下对齐)
  SSM 状态在位置 64 处正确缓存 ✅

no_buffer 模式:
  cache_len = 100 (未对齐)
  SSM 状态在位置 100 处缓存 ❌
  但 FLA 的 SSM 状态只在 64 的倍数位置是有效的!

为什么这是个问题?

根据 hybrid_linear_attn_backend.py 的注释:

1
2
3
4
5
# There are 3 cases for mamba_track_seqlen:
# 1) aligned with FLA_CHUNK_SIZE -> retrieve from last_recurrent_state
#    a) is the last position -> retrieve from last_recurrent_state
#    b) is NOT the last position -> retrieve from h
# 2) unaligned with FLA_CHUNK_SIZE -> retrieve from h

如果缓存的 SSM 状态位置不是 64 的倍数,在恢复时可能无法正确获取状态,导致计算结果错误。

修复方案

1
2
3
4
5
6
7
8
9
# cache_finished_req 修复
mamba_cache_chunk_size = get_global_server_args().mamba_cache_chunk_size

# no_buffer 模式也需要做 FLA 对齐
cache_len = (
    req.mamba_last_track_seqlen
    if self.enable_mamba_extra_buffer
    else (len(token_ids) // mamba_cache_chunk_size) * mamba_cache_chunk_size  # 向下对齐
)

Bug 总结

问题 原因 修复
问题一 cache_finished_req 没有复制 mamba 状态 使用 fork_from 复制状态到新 slot
问题二 _match_prefix_helper split 后丢失 mamba 引用 检查 child 的 mamba_value 并更新 best_value_len
问题三 no_buffer 模式未处理 FLA 64 位对齐 缓存长度向下对齐到 mamba_cache_chunk_size 的倍数

这三个问题共同导致了 no_buffer 模式下短请求无法正确进行前缀匹配。