qwen3next-cache

Qwen3Next Linear Attention Prefix Cache: An In-Depth Implementation Analysis

Contents

  1. Core Concepts
  2. Why Linear Attention Cannot Cache at Token Granularity
  3. Chunk-Granularity Caching
  4. The Radix Cache Tombstone Mechanism
  5. Prefix Matching Flow
  6. State Tracking
  7. Copy-on-Write Implementation
  8. Cache Save Flow
  9. End-to-End Example
  10. Key Data Structures

1. Core Concepts

1.1 The Recurrent Nature of Linear Attention

The fundamental difference between linear attention (Gated DeltaNet) and standard attention is its recurrent computation:

Standard attention (token-level parallelism):
  Output[t] = f(Q[t], K[0:t], V[0:t])  # each token is computed independently

Linear attention (recurrent dependency):
  H[t] = f(H[t-1], Q[t], K[t], V[t])   # the state is passed forward recursively
  Output[t] = g(H[t], Z[t])            # the output is generated from the state

This implies:

  • Standard attention: every token's K/V can be cached, and any already-computed token can be skipped.
  • Linear attention: computation must resume sequentially from the last valid state.

1.2 The FLA Chunk Mechanism

Flash Linear Attention (FLA) introduces chunks to balance parallelism against state management:

# chunk_delta_h.py
CHUNK_SIZE = 64  # fixed chunk size

# A sequence is split into chunks:
# [0-63], [64-127], [128-191], ...

Within a chunk:

  • Computation is fully parallel (similar to standard attention)
  • The kernel emits the intermediate states h and the final state last_recurrent_state

Across chunks:

  • Recurrent hand-off: h[i] serves as the initial_state of chunk[i+1]
  • Only states at chunk boundaries can be cached and reused
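
To make this contract concrete, here is a minimal, self-contained sketch: a toy additive recurrence standing in for the real gated delta rule (toy_linear_attn and all shapes are illustrative assumptions, not the FLA kernel). Computing chunk by chunk while threading the boundary state reproduces the full sequential result, which is exactly why boundary states are the only safe cache points.

import torch

CHUNK_SIZE = 64

def toy_linear_attn(q, k, v, h0=None):
    # Toy recurrence (NOT the real kernel): h[t] = h[t-1] + k[t] v[t]^T,
    # then o[t] = h[t]^T q[t]. Shapes: q/k [T, K], v [T, V], state [K, V].
    h = torch.zeros(k.shape[1], v.shape[1]) if h0 is None else h0.clone()
    outs = []
    for t in range(k.shape[0]):
        h += torch.outer(k[t], v[t])
        outs.append(h.T @ q[t])
    return torch.stack(outs), h

torch.manual_seed(0)
q, k, v = (torch.randn(150, 8) for _ in range(3))

out_full, _ = toy_linear_attn(q, k, v)  # one sequential pass

h, outs = None, []                      # chunked pass, threading boundary states
for s in range(0, 150, CHUNK_SIZE):
    o, h = toy_linear_attn(q[s:s+CHUNK_SIZE], k[s:s+CHUNK_SIZE], v[s:s+CHUNK_SIZE], h)
    outs.append(o)
assert torch.allclose(out_full, torch.cat(outs))  # identical results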

2. Why Linear Attention Cannot Cache at Token Granularity

2.1 The State Dependency Chain

# Pseudocode for the linear-attention state update
def linear_attention(tokens):
    h = initial_state
    outputs = []

    for t in range(len(tokens)):
        # every step depends on the previous step's state
        h = delta_rule(h, tokens[t])
        outputs.append(gate(h, tokens[t]))

    return outputs

# To start computing from position n, the state h[n-1] must already be available

2.2 Why Chunk Boundaries Are Mandatory

FLA's design only allows the state to be split at chunk boundaries (positions that are multiples of 64):

Sequence: [0, 1, 2, ..., 63, 64, 65, ..., 127, 128, ...]
          |<- Chunk 0 ->|  |<--- Chunk 1 --->|  |
          ✓ cacheable      ✓ cacheable          ✓ cacheable

Position 65: ✗ cannot be a cache boundary (not on a chunk boundary)

Reasons (see the sketch after this list):

  1. last_recurrent_state is only emitted at the end of a chunk
  2. The state at an interior position would have to be recomputed from the start or from the previous chunk boundary
  3. Reconstructing an interior state costs more than simply recomputing the tokens
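
The resume rule therefore reduces to flooring a cached length to a chunk multiple. A tiny illustrative helper (an assumption for exposition, not a function in the codebase):

CHUNK_SIZE = 64

def last_resumable_boundary(cached_len: int) -> int:
    # FLA only materializes states at multiples of CHUNK_SIZE, so a cached
    # prefix of length n can only be resumed from its last chunk boundary.
    return (cached_len // CHUNK_SIZE) * CHUNK_SIZE

assert last_resumable_boundary(65) == 64    # position 65 itself is unusable
assert last_resumable_boundary(150) == 128
assert last_resumable_boundary(63) == 0     # under one chunk: nothing cacheable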

3. Chunk-Granularity Caching

3.1 Core Constants

# chunk_delta_h.py
CHUNK_SIZE = 64  # FLA's fixed chunk size

# mamba_radix_cache.py
from sglang.srt.layers.attention.fla.chunk_delta_h import CHUNK_SIZE as FLA_CHUNK_SIZE

3.2 Alignment Arithmetic

# mamba_radix_cache.py:955-963
# Calculate the branching point. It is defined as the last aligned position that
# does not have a mamba value.
if len(value) > best_value_len:
    fla_chunk_aligned_seqlen = (
        sum(len(v) for v in value) // FLA_CHUNK_SIZE
    ) * FLA_CHUNK_SIZE  # floor-align to a multiple of 64
    mamba_branching_seqlen = (
        fla_chunk_aligned_seqlen if fla_chunk_aligned_seqlen > 0 else None
    )

Example:

Actual matched length: 150 tokens
Usable aligned length: 96 tokens
Unmatched remainder:   54 tokens (must be recomputed)

Why 96 rather than 128 (= 150 // 64 * 64)?
- Because the nodes on the matched path past position 96 carry a "tombstone" (mamba_value=None)
- A tombstone means that position has no valid mamba cache
- So the match must fall back to the last valid chunk boundary
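
In code form, the fallback rule might be modeled like this (usable_mamba_len is a hypothetical helper for exposition, not the codebase's API):

CHUNK_SIZE = 64

def usable_mamba_len(matched_len: int, last_state_boundary: int) -> int:
    # Hypothetical model: floor-align the raw match, then cap it at the last
    # position that still carries a stored mamba state (everything past it
    # on the matched path is tombstoned).
    aligned = (matched_len // CHUNK_SIZE) * CHUNK_SIZE
    return min(aligned, last_state_boundary)

assert usable_mamba_len(150, 96) == 96    # tombstones past 96 force the fallback
assert usable_mamba_len(150, 200) == 128  # no tombstone in range: pure alignment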

3.3 State Storage Layout

# returned by chunk_gated_delta_rule_fwd_h
h: [B, NT, H, K, V]  # NT = num_chunks
  # h[0]: state after chunk 0
  # h[1]: state after chunk 1
  # h[2]: state after chunk 2
  # ...

last_recurrent_state: [N, H, K, V]  # N = batch_size
  # final state of the whole sequence

# Only two things can be cached:
# 1. h[i] - the state at a chunk boundary
# 2. last_recurrent_state - the state at the end of the sequence

4. The Radix Cache Tombstone Mechanism

4.1 Tree Node Structure

class TreeNode:
    value: Optional[torch.Tensor]         # KV cache indices
    mamba_value: Optional[torch.Tensor]   # Mamba cache indices

    full_lock_ref: int     # KV cache lock count
    mamba_lock_ref: int    # Mamba cache lock count

4.2 Tombstone Definition

Tombstone: a node with mamba_value = None but value != None

# Example tree structure:
Root
├─ [0-95]   (value=[...], mamba_value=[...])  # full cache present
├─ [96-150] (value=[...], mamba_value=None)   # tombstone! KV cache only
└─ [96-191] (value=[...], mamba_value=[...])  # another branch with a full cache

Why are tombstones needed? (See the predicate sketch after this list.)

  • Linear attention and full attention cache at different granularities
  • KV cache (full attention): token granularity, exact down to any token
  • Mamba cache (linear attention): chunk granularity (multiples of 64)
  • Managing the two independently improves cache utilization
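
As a predicate over the two optional fields (is_tombstone is illustrative only; TreeNode itself is quoted above):

from typing import Optional
import torch

def is_tombstone(value: Optional[torch.Tensor],
                 mamba_value: Optional[torch.Tensor]) -> bool:
    # A tombstone still serves token-granular KV reuse, but offers no
    # chunk-boundary state for the linear-attention layers.
    return value is not None and mamba_value is None

assert is_tombstone(torch.arange(55), None)                   # KV only
assert not is_tombstone(torch.arange(96), torch.tensor([7]))  # full cache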

4.3 Split Eviction Policy

# mamba_radix_cache.py

def evict_mamba(self, mamba_num: int):
    """Evict only the mamba cache, keeping the KV cache."""
    while need_evict:
        if len(x.children) > 0:
            # internal node: free only the mamba state, leaving a tombstone
            self.req_to_token_pool.mamba_pool.free(x.mamba_value)
            self._tombstone_internal_node(x)  # sets mamba_value = None
        else:
            # leaf node: free both
            self._evict_leaf_node(x, True)

def evict(self, full_num_tokens: int):
    """Evict KV cache (leaf nodes), freeing the mamba state as well."""
    # only leaf nodes can be evicted
    # frees both value and mamba_value

5. Prefix Matching Flow

5.1 The Match-Prefix Algorithm

# mamba_radix_cache.py:900-967
def _match_prefix_helper(self, key: RadixKey) -> Tuple[List, TreeNode, Optional[int]]:
    """
    Returns: (matched_value_list, last_node, mamba_branching_seqlen)

    Core logic:
    1. Walk the radix tree to find the longest matching path
    2. Record the last node that carries a mamba_value
    3. Compute mamba_branching_seqlen
    """
    node = self.root_node
    value = []
    best_value_len = 0
    best_last_node = node

    # walk the tree
    while len(key) > 0 and child_key in node.children.keys():
        child = node.children[child_key]

        # update the best node: the current node must carry a mamba_value
        if node.mamba_value is not None:
            best_value_len = len(value)
            best_last_node = node

        # ... continue matching

    # compute the branching point
    if len(value) > best_value_len:
        # floor-align to a chunk boundary
        fla_chunk_aligned_seqlen = (
            sum(len(v) for v in value) // FLA_CHUNK_SIZE
        ) * FLA_CHUNK_SIZE
        mamba_branching_seqlen = (
            fla_chunk_aligned_seqlen if fla_chunk_aligned_seqlen > 0 else None
        )
    else:
        mamba_branching_seqlen = None

    return value[:best_value_len], best_last_node, mamba_branching_seqlen

5.2 A Matching Example

Scenario: the request's token sequence partially matches the cache.

Request sequence: "A B C D E F G H I J K L M N ..."
Cache entry:      "A B C D E F" (150 tokens, mamba_value present)
                  ↓ matches 150 tokens
                  but only the first 96 are chunk-aligned

Reused from cache: 96 tokens
To be computed:    54 tokens

Flow:

  1. Walk the tree to the matching node [0-150], which carries a mamba_value
  2. Compute mamba_branching_seqlen = 96 (floor-aligned)
  3. Return:
    • device_indices: KV cache indices for 96 tokens
    • mamba_value: mamba cache indices
    • mamba_branching_seqlen: 96

6. State Tracking

6.1 Why Is Tracking Needed?

During prefill, after a new chunk is computed its state must be saved into the cache. Two pieces of state have to be tracked:

  1. Conv state: the last conv_kernel_size inputs of the convolution window
  2. SSM state: the recurrent state after each chunk boundary

6.2 Tracking Fields

# ForwardBatch (schedule_batch.py)
mamba_track_indices: torch.Tensor    # [b] target cache slot to track into
mamba_track_mask: torch.Tensor       # [b] whether tracking is needed
mamba_track_seqlens: torch.Tensor    # [b] sequence length to track up to

# Req (schedule_batch.py)
mamba_ping_pong_track_buffer: Tensor   # [2] ping-pong buffer
mamba_next_track_idx: int              # next buffer index to use
mamba_last_track_seqlen: Optional[int] # seq_len of the last tracked state
mamba_branching_seqlen: Optional[int]  # branching-point seq_len

6.3 The Tracking Computation Flow

6.3.1 The Prepare Phase

# schedule_batch.py:1612-1664
def _mamba_radix_cache_v2_req_prepare_for_extend(self, req, ...):
    # decide whether tracking is needed
    mask = (req.extend_input_len // FLA_CHUNK_SIZE) * FLA_CHUNK_SIZE > 0
    mamba_track_mask_cpu.append(mask)

    if mask:
        # full tracked length
        mamba_track_seqlen = len(req.prefix_indices) + req.extend_input_len

        # tracked length after alignment
        mamba_track_seqlen_aligned = (
            len(req.prefix_indices)
            + (req.extend_input_len // FLA_CHUNK_SIZE) * FLA_CHUNK_SIZE
        )

        # handle the branching point
        if req.mamba_branching_seqlen is not None:
            branching_seqlen_aligned_mask = (
                req.mamba_branching_seqlen - len(req.prefix_indices)
            ) % FLA_CHUNK_SIZE == 0

            if (req.mamba_branching_seqlen > len(req.prefix_indices)
                and req.mamba_branching_seqlen < mamba_track_seqlen
                and branching_seqlen_aligned_mask):
                # the branching point must be tracked
                mamba_track_seqlen = req.mamba_branching_seqlen + 1
                mamba_track_seqlen_aligned = req.mamba_branching_seqlen

        req.mamba_last_track_seqlen = mamba_track_seqlen_aligned

6.3.2 Computing Conv-State Tracking Indices

# hybrid_linear_attn_backend.py:259-298
def _init_track_conv_indices(self, query_start_loc, forward_batch):
    """
    Compute the positions of the convolution state to track.

    The conv state stores the last conv_kernel_size inputs, so those
    positions must be gathered from the input sequence.
    """
    conv_state_len = self.conv_states_shape[-1]  # conv_kernel_size

    # aligned length
    lens_to_track = forward_batch.mamba_track_seqlens - forward_batch.extend_prefix_lens
    aligned_len = (lens_to_track // FLA_CHUNK_SIZE) * FLA_CHUNK_SIZE

    # start position = aligned length - conv_kernel_size
    start_indices = query_start_loc[:-1] + aligned_len - conv_state_len
    start_indices = start_indices[forward_batch.mamba_track_mask]

    # build indices: [num_tracked, conv_kernel_size]
    indices = start_indices.unsqueeze(-1) + torch.arange(conv_state_len, ...)

    return indices.clamp(0, query_start_loc[-1] - 1)

Example:

extend_prefix_lens = 100  (already cached)
mamba_track_seqlens = 300 (track up to 300)
extend_input_len = 210    (210 new input tokens)

lens_to_track = 300 - 100 = 200
aligned_len = (200 // 64) * 64 = 192

With conv_kernel_size = 4:
start_indices = query_start_loc[:-1] + 192 - 4 = query_start_loc[:-1] + 188

Extracted positions: [188, 189, 190, 191] (the last 4 inputs)
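
To double-check the arithmetic, here is a self-contained re-derivation with toy tensors (it mirrors the quoted logic rather than calling into sglang; all values are the ones assumed above):

import torch

FLA_CHUNK_SIZE, CONV_K = 64, 4

query_start_loc = torch.tensor([0, 210])   # cumulative lengths, one request
track_seqlens = torch.tensor([300])
prefix_lens = torch.tensor([100])
track_mask = torch.tensor([True])

lens_to_track = track_seqlens - prefix_lens                        # [200]
aligned_len = (lens_to_track // FLA_CHUNK_SIZE) * FLA_CHUNK_SIZE   # [192]

start = (query_start_loc[:-1] + aligned_len - CONV_K)[track_mask]  # [188]
indices = start.unsqueeze(-1) + torch.arange(CONV_K)
indices = indices.clamp(0, query_start_loc[-1] - 1)
assert indices.tolist() == [[188, 189, 190, 191]]                  # the conv window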

6.3.3 Computing SSM-State Tracking Indices

# hybrid_linear_attn_backend.py:300-378
def _init_track_ssm_indices(self, mamba_cache_indices, forward_batch):
    """
    Compute the tracking indices for the SSM state.

    The FLA kernel outputs:
    - h: [NT, H, K, V] - the state at each chunk boundary
    - last_recurrent_state: [N, H, K, V] - the state at the end of the sequence

    We must decide whether to take the state from h or from last_recurrent_state.
    """
    # number of chunks
    num_h_states = (extend_seq_lens - 1) // FLA_CHUNK_SIZE + 1

    # offsets into the flattened h tensor
    track_ssm_src_offset = torch.zeros_like(num_h_states)
    track_ssm_src_offset[1:] = torch.cumsum(num_h_states[:-1], dim=0)

    # filter down to the requests that need tracking
    lens_to_track = mamba_track_seqlens - prefix_lens
    lens_masked = lens_to_track[mamba_track_mask]
    offset_masked = track_ssm_src_offset[mamba_track_mask]
    dst_masked = mamba_track_indices[mamba_track_mask]

    # aligned or not?
    is_aligned = (lens_masked % FLA_CHUNK_SIZE) == 0

    # Case 1: aligned - take from last_recurrent_state
    track_ssm_final_src = mamba_cache_indices[mamba_track_mask][is_aligned]
    track_ssm_final_dst = dst_masked[is_aligned]

    # Case 2: not aligned - take from h (an intermediate state)
    not_aligned = ~is_aligned
    track_ssm_h_src = offset_masked[not_aligned] + (
        lens_masked[not_aligned] // FLA_CHUNK_SIZE
    )
    track_ssm_h_dst = dst_masked[not_aligned]

    return track_ssm_h_src, track_ssm_h_dst, track_ssm_final_src, track_ssm_final_dst

Example:

extend_seq_lens = 210
prefix_lens = 100
lens_to_track = 110

num_h_states = (210 - 1) // 64 + 1 = 4
Chunk layout:
- h[0]: covers positions [0-63]
- h[1]: covers positions [64-127]
- h[2]: covers positions [128-191]
- last_recurrent_state: covers positions [0-210]

lens_masked = 110
is_aligned = (110 % 64 == 0) = False

So the state must come from h:
110 // 64 = 1
take the state from h[offset + 1]
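
Again as a self-contained re-derivation with toy tensors (mirroring the quoted logic for a single-request batch; the slot numbers are made up):

import torch

FLA_CHUNK_SIZE = 64

extend_seq_lens = torch.tensor([210])
mamba_track_seqlens = torch.tensor([210])   # prefix 100 + lens_to_track 110
prefix_lens = torch.tensor([100])
track_mask = torch.tensor([True])
mamba_track_indices = torch.tensor([9])     # destination slot (made up)

num_h_states = (extend_seq_lens - 1) // FLA_CHUNK_SIZE + 1   # [4]
offset = torch.zeros_like(num_h_states)
offset[1:] = torch.cumsum(num_h_states[:-1], dim=0)          # [0] for one request

lens_masked = (mamba_track_seqlens - prefix_lens)[track_mask]  # [110]
is_aligned = (lens_masked % FLA_CHUNK_SIZE) == 0               # [False]

h_src = offset[track_mask][~is_aligned] + lens_masked[~is_aligned] // FLA_CHUNK_SIZE
h_dst = mamba_track_indices[track_mask][~is_aligned]
assert h_src.tolist() == [1] and h_dst.tolist() == [9]  # read h[1], write slot 9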

7. Copy-on-Write Implementation

7.1 Why CoW?

When multiple requests share the same prefix and one of them needs to keep computing, the Mamba cache must be duplicated:

Request A: "Hello world" [cached]
Request B: "Hello world, how are you?" [reuses A's cache]

Problem:
- A and B share the same mamba_value
- If B keeps decoding, it mutates the cache in place
- A's cache gets corrupted!

Solution: Copy-on-Write
- B forks an independent copy of the cache
- The two no longer interfere
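
A toy pool makes the hazard and the fix obvious (illustrative slots only, not the MambaPool API):

import torch

pool = torch.zeros(4, 2)               # toy mamba pool: 4 slots of state
pool[0] = torch.tensor([1.0, 2.0])     # slot 0: request A's cached prefix state

# Copy-on-write: B forks slot 0 into its own slot before mutating anything.
dst = 1
pool[dst] = pool[0]                    # fork: copy the shared state
pool[dst] += 10.0                      # B's decode mutates only its own copy

assert torch.equal(pool[0], torch.tensor([1.0, 2.0]))  # A's state is intact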

7.2 The CoW Path

# mamba_radix_cache.py:440-459
def match_prefix(self, key, cow_mamba=False, req=None):
    value, last_node, mamba_branching_seqlen = self._match_prefix_helper(key)

    # Copy-on-Write for Mamba
    if cow_mamba and last_node.mamba_value is not None:
        if req.mamba_pool_idx is None:
            # allocate a fresh cache slot
            dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
            if dst_index is None:
                # out of space: evict and retry
                self.inc_lock_ref(last_node)
                self.evict_mamba(1)
                dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
                self.dec_lock_ref(last_node)

            # copy the state
            src_index = last_node.mamba_value
            self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)
            req.mamba_pool_idx = dst_index
        else:
            # already has its own slot: refresh it
            src_index = last_node.mamba_value
            dst_index = req.mamba_pool_idx.unsqueeze(0)
            self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)

    return MatchResult(device_indices=torch.cat(value), ...)

7.3 The Fork Primitive

# memory_pool.py:304-309
def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]:
    """Fork a fresh copy of a mamba cache slot."""
    dst_index = self.alloc(1)
    if dst_index is None:
        return None
    self.copy_from(src_index, dst_index)
    return dst_index

def copy_from(self, src_index, dst_index):
    """Copy both the conv and the temporal (SSM) state."""
    for i in range(len(self.mamba_cache.conv)):
        self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][:, src_index]
    self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[:, src_index]

7.4 Where Fork Is Used

# mamba_radix_cache.py:576-689
def cache_unfinished_req(self, req: Req):
    """
    Cache a request that is not finished yet (chunked prefill).
    """
    # fork the mamba cache
    mamba_value = self.req_to_token_pool.get_mamba_indices(req.req_pool_idx)
    mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(mamba_value)

    # if the fork failed, evict and retry
    if mamba_value_forked is None:
        self.evict_mamba(1)
        mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(mamba_value)

    # insert a new tree node
    new_prefix_len, mamba_exist = self.insert(
        RadixKey(page_aligned_token_ids, req.extra_key),
        page_aligned_kv_indices,
        mamba_value_forked,
    )

    # if a mamba cache already existed there, free the forked copy
    if mamba_exist:
        self.req_to_token_pool.mamba_pool.free(mamba_value_forked)

8. Cache Save Flow

8.1 Caching a Finished Request

# mamba_radix_cache.py:480-572
def cache_finished_req(self, req: Req, is_insert: bool = True):
    # length of KV already committed
    kv_committed_len = req.pop_committed_kv_cache()

    token_ids = (req.origin_input_ids + req.output_ids)[:kv_committed_len]
    kv_indices = self.req_to_token_pool.req_to_token[
        req.req_pool_idx, :kv_committed_len
    ]

    if is_insert:
        # effective cache length
        cache_len = (
            req.mamba_last_track_seqlen
            if self.enable_mamba_extra_buffer
            else len(token_ids)
        )

        # align to page_size
        if self.page_size != 1:
            page_aligned_len = len(kv_indices) // self.page_size * self.page_size
            page_aligned_kv_indices = kv_indices[:page_aligned_len]
            page_aligned_token_ids = token_ids[:page_aligned_len]

        # take the mamba cache out of the ping-pong buffer
        if self.enable_mamba_extra_buffer:
            mamba_ping_pong_track_buffer_to_keep = (
                self.req_to_token_pool.get_mamba_ping_pong_other_idx(
                    req.mamba_ping_pong_track_buffer
                )
            )
            mamba_value = (
                req.mamba_ping_pong_track_buffer[
                    mamba_ping_pong_track_buffer_to_keep
                ].unsqueeze(-1).clone()
            )
        else:
            mamba_value = req.mamba_pool_idx.unsqueeze(-1).clone()

        # insert into the radix tree
        new_prefix_len, mamba_exist = self.insert(
            RadixKey(page_aligned_token_ids, req.extra_key),
            page_aligned_kv_indices,
            mamba_value,
        )

        # free the KV cache that was not used
        self.token_to_kv_pool_allocator.free(
            kv_indices[req.cache_protected_len : new_prefix_len]
        )

        # if a mamba cache already existed in the tree, don't keep the buffer
        if mamba_exist:
            mamba_ping_pong_track_buffer_to_keep = None

        # release the request slot
        self.req_to_token_pool.free(
            req.req_pool_idx,
            free_mamba_cache=free_mamba_cache,
            mamba_ping_pong_track_buffer_to_keep=mamba_ping_pong_track_buffer_to_keep,
        )

8.2 State Tracking During Extend

# hybrid_linear_attn_backend.py:585-617
def _track_mamba_state_extend(self, forward_batch, h, ssm_states, forward_metadata):
    """
    During extend, copy the tracked SSM states into their cache slots.
    """
    if (forward_batch.mamba_track_mask is not None
        and forward_batch.mamba_track_mask.any()):

        h = h.squeeze(0)

        # Case 1: take from h (not aligned)
        if forward_metadata.track_ssm_h_src.numel() > 0:
            ssm_states[forward_metadata.track_ssm_h_dst] = h[
                forward_metadata.track_ssm_h_src
            ].to(ssm_states.dtype, copy=False)

        # Case 2: take from last_recurrent_state (aligned)
        if forward_metadata.track_ssm_final_src.numel() > 0:
            ssm_states[forward_metadata.track_ssm_final_dst] = ssm_states[
                forward_metadata.track_ssm_final_src
            ]

Conv-state tracking:

# hybrid_linear_attn_backend.py:997-1023
# inside forward_extend

# gather the conv states that need tracking
if (forward_batch.mamba_track_mask is not None
    and forward_batch.mamba_track_mask.any()):

    conv_dst = forward_batch.mamba_track_indices

    # extract the conv window from the input sequence
    mixed_qkv_to_track = mixed_qkv[
        :, forward_metadata.track_conv_indices
    ].transpose(0, 1)

    # write into the cache
    mask_indices = forward_batch.mamba_track_mask.nonzero(as_tuple=True)[0]
    conv_states[conv_dst[mask_indices]] = mixed_qkv_to_track

8.3 State Tracking During Decode

# hybrid_linear_attn_backend.py:555-583
def _track_mamba_state_decode(self, forward_batch, conv_states, ssm_states, cache_indices):
    """
    During decode, copy the updated states into their persistent cache slots.
    """
    if forward_batch.mamba_track_mask is not None:
        track_mamba_states_if_needed(
            conv_states,      # source: working slots
            ssm_states,       # source: working slots
            cache_indices,    # source indices
            forward_batch.mamba_track_mask,
            forward_batch.mamba_track_indices,  # destination indices
            forward_batch.batch_size,
        )

The Triton kernel implementation:

# hybrid_linear_attn_backend.py:64-122
@triton.jit
def track_mamba_state_if_needed_kernel(
    conv_states_ptr,
    ssm_states_ptr,
    cache_indices_ptr,       # source indices
    mamba_track_mask_ptr,    # whether to track
    mamba_track_indices_ptr, # destination indices
    ...
):
    batch_idx = tl.program_id(0)

    # check the mask
    track_mask = tl.load(mamba_track_mask_ptr + batch_idx)
    if not track_mask:
        return

    # load the indices
    src_idx = tl.load(cache_indices_ptr + batch_idx)
    dst_idx = tl.load(mamba_track_indices_ptr + batch_idx)

    # copy the conv states
    for offset in range(0, conv_state_numel_per_row, BLOCK_SIZE):
        src_ptr = conv_states_ptr + src_idx * conv_state_stride_0 + offset
        dst_ptr = conv_states_ptr + dst_idx * conv_state_stride_0 + offset
        data = tl.load(src_ptr, mask=mask, other=0.0)
        tl.store(dst_ptr, data, mask=mask)

    # copy the SSM states
    for offset in range(0, ssm_state_numel_per_row, BLOCK_SIZE):
        src_ptr = ssm_states_ptr + src_idx * ssm_state_stride_0 + offset
        dst_ptr = ssm_states_ptr + dst_idx * ssm_state_stride_0 + offset
        data = tl.load(src_ptr, mask=mask, other=0.0)
        tl.store(dst_ptr, data, mask=mask)

9. End-to-End Example

9.1 Scenario Setup

Initial state:
- Cache: empty
- FLA_CHUNK_SIZE: 64

Request 1: "Hello world, this is a test" (8 tokens)
Handling: extend, no cache hit
Result:
- Compute positions 0-7 (8 tokens)
- Aligned length = 0 (8 < 64)
- No mamba cache is saved (less than one chunk)

Request 2: "Hello world, this is a test of the system" (12 tokens)
Handling: extend, partial cache hit
Match:
- KV cache: matches "Hello world, this is a test" (8 tokens)
- Mamba cache: matches 0 tokens (aligned down to 0)
Compute:
- Newly compute positions 8-11 (4 tokens)
- Aligned length = 0 (12 < 64)
- No mamba cache is saved

Request 3: a very long input (150 tokens)
Assume the first 100 tokens have already been processed
Handling: extend (the remaining 50 tokens)
Match:
- KV cache: matches the first 100 tokens
- Mamba cache:
  - find the longest matching path
  - assume the cache holds 96 tokens (1.5 chunks)
  - mamba_branching_seqlen = 96 (aligned to a chunk boundary)
Compute:
- Reuse KV: 0-99
- Reuse mamba: 0-95 (from the cached mamba state)
- Newly compute: 100-149 (50 tokens)
- Aligned length = (150 // 64) * 64 = 128
Save:
- KV cache: the first 128 tokens (aligned to page_size=1)
- Mamba cache: the state at position 128
  - conv_state: the last 4 inputs [124-127]
  - ssm_state: the chunk-boundary state (128)

Request 4: continue decoding (1 token)
Handling: decode
State:
- mamba_pool_idx: already allocated (from request 3's cache)
- conv_state: the conv window at position 128
- ssm_state: the state at position 128
Compute:
- 1 input token
- conv_state: updated (sliding window)
- ssm_state: updated (recurrence)
Tracking:
- mamba_track_mask: True (tracking fires once every mamba_track_interval tokens)
- the state is copied into the track slot

9.2 Evolution of the Tree

Initially:
Root

After request 3 is processed:
Root
└─ [0-127] (value=[0-127], mamba_value=[slot_10])

After request 3's cache_finished_req:
(only the first 128 positions are actually saved)
Root
└─ [0-127] (value=[0-127], mamba_value=[slot_10])

After request 4 decodes N times:
(assuming mamba_track_interval=64)
Root
├─ [0-127] (value=[0-127], mamba_value=[slot_10])
└─ [0-191] (value=[0-191], mamba_value=[slot_11])
    (the new full sequence, including the decoded tokens)

9.3 A Tombstone Example

Scenario: two requests share a prefix, then fork.

Request A: "Common prefix + branch A" (200 tokens)
Request B: "Common prefix + branch B" (200 tokens)

Step 1: after request A finishes
Root
└─ [0-199] (value=[...], mamba_value=[slot_20])

Step 2: request B uses CoW
- matches prefix [0-199]
- forks mamba_value: slot_20 -> slot_21
- request B computes branch B independently

Step 3: request B is evicted (LRU)
Root
├─ [0-199] (value=[...], mamba_value=[slot_20]) # request A
└─ [0-199] (value=[...], mamba_value=None)      # tombstone!

Step 4: new request C (150 tokens, partially matching the prefix)
Match:
- KV cache: matches [0-149]
- Mamba cache:
  - walk to the [0-199] node
  - find mamba_value=None (tombstone)
  - fall back to Root
  - effective match: 0 tokens

10. Key Data Structures

10.1 ForwardBatch

@dataclass
class ForwardBatch:
    # mamba tracking
    mamba_track_indices: torch.Tensor      # [b] destination cache slot
    mamba_track_mask: torch.Tensor         # [b] whether tracking is needed
    mamba_track_seqlens: torch.Tensor      # [b] seq_len to track up to

    # extend-related
    extend_seq_lens: torch.Tensor          # [b] input length per request
    extend_prefix_lens: torch.Tensor       # [b] prefix length per request
    extend_seq_lens_cpu: List[int]         # [b] seq_len on the CPU

10.2 ForwardMetadata

@dataclass(kw_only=True)
class ForwardMetadata:
    query_start_loc: torch.Tensor          # cumulative sequence lengths
    mamba_cache_indices: torch.Tensor      # cache slot indices

    # tracking-related
    track_conv_indices: Optional[torch.Tensor]   # conv-state positions
    track_ssm_h_src: Optional[torch.Tensor]      # SSM h source indices
    track_ssm_h_dst: Optional[torch.Tensor]      # SSM h destination indices
    track_ssm_final_src: Optional[torch.Tensor]  # SSM final source indices
    track_ssm_final_dst: Optional[torch.Tensor]  # SSM final destination indices

10.3 Req

class Req:
    # mamba cache slot
    mamba_pool_idx: Optional[int]           # slot currently in use

    # ping-pong buffer
    mamba_ping_pong_track_buffer: Tensor    # [2] double buffer
    mamba_next_track_idx: int               # next buffer index

    # tracking state
    mamba_last_track_seqlen: Optional[int]  # seq_len of the last tracked state
    mamba_branching_seqlen: Optional[int]   # branching-point seq_len

    # prefix info
    prefix_indices: torch.Tensor            # matched token indices
    last_node: TreeNode                     # radix tree node

10.4 MambaPool Layout

class MambaPool:
    @dataclass(frozen=True, kw_only=True)
    class State:
        conv: List[torch.Tensor]             # [num_layers, pool_size+1, dim]
        temporal: torch.Tensor               # [num_layers, pool_size+1, H, K, V]

    @dataclass(frozen=True, kw_only=True)
    class SpeculativeState(State):
        intermediate_ssm: torch.Tensor        # [num_layers, pool_size+1, draft, H, K, V]
        intermediate_conv_window: List[torch.Tensor]  # [num_layers, pool_size+1, draft, dim, K-1]

Summary

Core mechanisms of the Linear Attention prefix cache

  1. Chunk granularity: states can only be cached at FLA_CHUNK_SIZE (64) boundaries
  2. Alignment rule: mamba_branching_seqlen = (len // 64) * 64
  3. Tombstones: let the KV cache and the Mamba cache be managed independently
  4. Copy-on-Write: state duplication when multiple requests share a prefix
  5. State tracking:
    • Conv state: the last K-1 inputs
    • SSM state: the recurrent state at a chunk boundary
  6. Double buffering: a ping-pong buffer alternates between tracked states

Differences from Standard Attention

Property            Standard Attention        Linear Attention
Cache granularity   Token                     Chunk (64 tokens)
State dependency    None (fully parallel)     Recurrent (sequential)
Prefix matching     Exact, at any length      Aligned to chunk boundaries
Cache structure     Single KV cache           KV + Conv + SSM
Eviction policy     Unified                   Split (tombstone)

File Path Index

Component                 File path
Mamba Radix Cache         python/sglang/srt/mem_cache/mamba_radix_cache.py
Hybrid Linear Backend     python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
Chunk Gated Delta Rule    python/sglang/srt/layers/attention/fla/chunk.py
Chunk Delta H Kernel      python/sglang/srt/layers/attention/fla/chunk_delta_h.py
Mamba Metadata            python/sglang/srt/layers/attention/mamba/mamba2_metadata.py
Schedule Batch            python/sglang/srt/managers/schedule_batch.py
Memory Pool               python/sglang/srt/mem_cache/memory_pool.py

Key Code Location Index

  • Tombstone creation: mamba_radix_cache.py:1099-1103
  • Mamba cache eviction: mamba_radix_cache.py:729-759
  • Prefix matching: mamba_radix_cache.py:900-967
  • Branching-point computation: mamba_radix_cache.py:955-966
  • Conv tracking indices: hybrid_linear_attn_backend.py:259-298
  • SSM tracking indices: hybrid_linear_attn_backend.py:300-378
  • State tracking: hybrid_linear_attn_backend.py:585-617, 555-583
  • CoW implementation: mamba_radix_cache.py:440-459, memory_pool.py:304-309
  • Ping-pong buffer: schedule_batch.py:1645-1663
  • Chunk kernel: chunk_delta_h.py:33-272

This implementation works with, rather than against, the recurrent nature of linear attention: by caching at chunk granularity and tracking boundary states, it delivers prefix caching while preserving compute efficiency.