Qwen3Next Linear Attention Prefix Cache: An In-Depth Implementation Analysis
Table of Contents
- Core Concepts
- Why Linear Attention Cannot Cache at Token Granularity
- Chunk-Granularity Caching
- The Radix Cache Tombstone Mechanism
- Prefix Matching Flow
- State Tracking
- Copy-on-Write
- Cache Save Flow
- End-to-End Example
- Key Data Structures
1. Core Concepts
1.1 The Recurrent Nature of Linear Attention
The fundamental difference between Linear Attention (Gated Delta Net) and standard attention is its recurrent computation:
```
Standard attention (token-level parallelism):
    Output[t] = f(Q[t], K[0:t], V[0:t])      # each token computed independently

Linear attention (recurrent dependency):
    H[t]      = f(H[t-1], Q[t], K[t], V[t])  # state carried forward recursively
    Output[t] = g(H[t], Z[t])                # output generated from the state
```
This means:
- Standard attention: the K/V of every token can be cached, and any already-computed token can be skipped
- Linear attention: computation must resume sequentially from the last valid state
1.2 The FLA Chunk Mechanism
Flash Linear Attention (FLA) introduces chunks to balance parallelism against state management:
```python
# chunk_delta_h.py
CHUNK_SIZE = 64  # fixed chunk size

# A sequence is split into chunks:
# [0-63], [64-127], [128-191], ...
```
Within a chunk:
- fully parallel computation (similar to standard attention)
- emits the intermediate states h and the final state last_recurrent_state

Across chunks:
- recurrent hand-off: h[i] serves as chunk[i+1]'s initial_state (see the sketch below)
- only states at chunk boundaries can be cached and reused
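To make the boundary property concrete, here is a minimal NumPy sketch (ours, not the FLA kernel; it uses a plain additive state update instead of the gated delta rule): computing chunk by chunk while handing each boundary state forward reproduces the full sequential scan exactly, which is why boundary states are the natural unit of caching.

```python
# Minimal sketch: chunked recurrence with boundary-state hand-off matches the
# full scan. The update H += k v^T is a simplification of the gated delta rule.
import numpy as np

CHUNK_SIZE = 64
D = 8  # toy head dimension

def scan(keys, values, h0):
    """Sequential state update: H[t] = H[t-1] + k_t v_t^T."""
    h = h0.copy()
    for k, v in zip(keys, values):
        h += np.outer(k, v)
    return h

rng = np.random.default_rng(0)
T = 3 * CHUNK_SIZE
K, V = rng.standard_normal((T, D)), rng.standard_normal((T, D))

# Full scan from scratch.
h_full = scan(K, V, np.zeros((D, D)))

# Chunked scan: only boundary states are ever materialized, hence cacheable.
h = np.zeros((D, D))
boundary_states = []
for start in range(0, T, CHUNK_SIZE):
    h = scan(K[start:start + CHUNK_SIZE], V[start:start + CHUNK_SIZE], h)
    boundary_states.append(h.copy())

# Resuming from the state cached at position 64 only needs tokens 64..T-1.
h_resumed = scan(K[CHUNK_SIZE:], V[CHUNK_SIZE:], boundary_states[0])
assert np.allclose(h_full, h_resumed)
```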
2. Why Linear Attention Cannot Cache at Token Granularity
2.1 The State Dependency Chain
```python
# Pseudocode for the linear-attention state update
def linear_attention(tokens):
    h = initial_state
    outputs = []
    for t in range(len(tokens)):
        # every step depends on the previous step's state
        h = delta_rule(h, tokens[t])
        outputs.append(gate(h, tokens[t]))
    return outputs

# To start computing from position n, the state h[n-1] must be available
```
2.2 Why Chunk Boundaries Are Necessary
FLA's design only allows the state to be split at chunk boundaries (positions that are multiples of 64):
```
Sequence: [0, 1, 2, ..., 63, 64, 65, ..., 127, 128, ...]
          |<- Chunk 0 ->|  |<---  Chunk 1  --->|
          ✓ cacheable      ✓ cacheable           ✓ cacheable
Position 65: ✗ cannot be a cache boundary (not on a chunk boundary)
```
Reasons (see the helper sketch below):
- last_recurrent_state is only emitted at the end of a chunk
- the state at an interior position would have to be recomputed from the start or from the previous chunk boundary
- computing an interior state costs more than simply recomputing from scratch
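A tiny helper (ours, for illustration only) capturing the fallback rule:

```python
# Hypothetical helper: the largest chunk-aligned prefix length <= matched_len.
# Anything short of a full chunk yields 0, i.e. recompute from scratch.
CHUNK_SIZE = 64

def usable_cache_len(matched_len: int) -> int:
    return (matched_len // CHUNK_SIZE) * CHUNK_SIZE

assert usable_cache_len(65) == 64    # position 65 is not a valid boundary
assert usable_cache_len(150) == 128
assert usable_cache_len(63) == 0     # less than one chunk
```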
3. Chunk-Granularity Caching
3.1 Core Constants
```python
# chunk_delta_h.py
CHUNK_SIZE = 64  # FLA's fixed chunk size

# mamba_radix_cache.py
from sglang.srt.layers.attention.fla.chunk_delta_h import CHUNK_SIZE as FLA_CHUNK_SIZE
```
3.2 Alignment Computation
```python
# mamba_radix_cache.py:955-963
# Calculate the branching point. It is defined as the last aligned position that
# does not have a mamba value.
if len(value) > best_value_len:
    fla_chunk_aligned_seqlen = (
        sum(len(v) for v in value) // FLA_CHUNK_SIZE
    ) * FLA_CHUNK_SIZE  # align down to a multiple of 64
    mamba_branching_seqlen = (
        fla_chunk_aligned_seqlen if fla_chunk_aligned_seqlen > 0 else None
    )
```
Example:
```
Actual matched length: 150 tokens
Usable aligned length: 96 tokens
Portion to recompute:  54 tokens

Why 96 rather than 128 (the naive 150 // 64 * 64)?
- the nodes on the matched path beyond position 96 are "tombstones" (mamba_value=None)
- a tombstone means there is no valid mamba cache at that position
- matching must fall back to the last boundary that still carries a valid state
```
3.3 State Storage Layout
```python
# returned by chunk_gated_delta_rule_fwd_h
h: [B, NT, H, K, V]  # NT = num_chunks
# h[0]: state after chunk 0
# h[1]: state after chunk 1
# h[2]: state after chunk 2
# ...
last_recurrent_state: [N, H, K, V]  # N = batch_size
# final state of the whole sequence

# Only two things can be cached:
# 1. h[i] - the state at a chunk boundary
# 2. last_recurrent_state - the state at the end of the sequence
```
4. The Radix Cache Tombstone Mechanism
4.1 Tree Node Structure
```python
class TreeNode:
    value: Optional[torch.Tensor]        # KV cache indices
    mamba_value: Optional[torch.Tensor]  # Mamba cache indices
    full_lock_ref: int                   # KV cache lock count
    mamba_lock_ref: int                  # Mamba cache lock count
```
4.2 Tombstone Definition
A tombstone is a node with mamba_value = None but value != None.
```
# Example tree:
Root
├─ [0-95]   (value=[...], mamba_value=[...])  # full cache present
├─ [96-150] (value=[...], mamba_value=None)   # tombstone! KV cache only
└─ [96-191] (value=[...], mamba_value=[...])  # another path with a full cache
```
Why are tombstones needed? (see the sketch below)
- linear attention and full attention cache at different granularities
- KV cache (full attention): token granularity, exact down to any token
- Mamba cache (linear attention): chunk granularity (multiples of 64)
- managing the two independently improves cache utilization
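A hedged sketch of the tombstone predicate implied by the node structure above (the helper name is ours, not from the SGLang source):

```python
from typing import Optional
import torch

class TreeNode:
    value: Optional[torch.Tensor]        # KV cache indices (token granularity)
    mamba_value: Optional[torch.Tensor]  # mamba cache indices (chunk granularity)

def is_tombstone(node: TreeNode) -> bool:
    # The KV cache is still present, but the mamba state was evicted independently.
    return node.value is not None and node.mamba_value is None
```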
4.3 Separate Eviction Policies
```python
# mamba_radix_cache.py
def evict_mamba(self, mamba_num: int):
    """Evict only the mamba cache; keep the KV cache."""
    while need_evict:
        if len(x.children) > 0:
            # internal node: free only the mamba state, turning it into a tombstone
            self.req_to_token_pool.mamba_pool.free(x.mamba_value)
            self._tombstone_internal_node(x)  # sets mamba_value = None
        else:
            # leaf node: free both
            self._evict_leaf_node(x, True)

def evict(self, full_num_tokens: int):
    """Evict KV cache (leaf nodes), freeing the mamba state as well."""
    # only leaf nodes can be evicted
    # frees both value and mamba_value
```
5. Prefix Matching Flow
5.1 The MatchPrefix Algorithm
```python
# mamba_radix_cache.py:900-967
def _match_prefix_helper(self, key: RadixKey) -> Tuple[List, TreeNode, Optional[int]]:
    """
    Returns: (matched_value_list, last_node, mamba_branching_seqlen)

    Core logic:
    1. Walk the radix tree to find the longest matching path
    2. Remember the last node that has a mamba_value
    3. Compute mamba_branching_seqlen
    """
    node = self.root_node
    value = []
    best_value_len = 0
    best_last_node = node

    # walk the tree
    while len(key) > 0 and child_key in node.children.keys():
        child = node.children[child_key]
        # update the best node: the current node must have a mamba_value
        if node.mamba_value is not None:
            best_value_len = len(value)
            best_last_node = node
        # ... continue matching

    # compute the branching point
    if len(value) > best_value_len:
        # align down to the chunk boundary
        fla_chunk_aligned_seqlen = (
            sum(len(v) for v in value) // FLA_CHUNK_SIZE
        ) * FLA_CHUNK_SIZE
        mamba_branching_seqlen = (
            fla_chunk_aligned_seqlen if fla_chunk_aligned_seqlen > 0 else None
        )
    else:
        mamba_branching_seqlen = None

    return value[:best_value_len], best_last_node, mamba_branching_seqlen
```
5.2 Matching Example
Scenario: the request's token sequence partially matches the cache.
```
Request sequence: "A B C D E F G H I J K L M N ..."
Cache entry:      "A B C D E F" (150 tokens, mamba_value present)
    ↓ matches 150 tokens,
      but only the first 96 are usable as a mamba prefix
    ↓
Reused from cache: 96 tokens
Newly computed:    54 tokens
```
Flow:
- Walk the tree and find the matching node: [0-150] has a mamba_value
- Compute mamba_branching_seqlen (aligned down)
- Return:
  - device_indices: KV cache indices for 96 tokens
  - mamba_value: the mamba cache index
  - mamba_branching_seqlen: 96
6. State Tracking
6.1 Why Is Tracking Needed?
During prefill, after a new chunk has been computed its state must be saved into the cache. Two kinds of state have to be tracked (a toy sketch follows the list):
- Conv state: the last conv_kernel_size inputs of the convolution window
- SSM state: the recurrent state at each chunk boundary
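A toy sketch (ours; shapes are illustrative and the recurrence is a stand-in, not the gated delta rule) of the two kinds of state:

```python
import torch

conv_kernel_size, dim = 4, 8
inputs = torch.randn(200, dim)  # toy per-token inputs for one request

# Conv state at an aligned position (say 192): a sliding window holding the
# last conv_kernel_size inputs seen up to that position.
conv_state = inputs[192 - conv_kernel_size:192]  # shape [4, dim]

# SSM state at the same boundary: whatever the recurrence has accumulated so
# far; here a stand-in outer-product accumulation over the first 192 tokens.
ssm_state = torch.zeros(dim, dim)
for t in range(192):
    ssm_state += torch.outer(inputs[t], inputs[t])
```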
6.2 Tracking Fields
```python
# ForwardBatch (schedule_batch.py)
mamba_track_indices: torch.Tensor  # [b] target cache slot being tracked into
mamba_track_mask: torch.Tensor     # [b] whether tracking is needed
mamba_track_seqlens: torch.Tensor  # [b] sequence length to track up to

# Req (schedule_batch.py)
mamba_ping_pong_track_buffer: Tensor     # [2] ping-pong buffer
mamba_next_track_idx: int                # index of the next buffer to use
mamba_last_track_seqlen: Optional[int]   # seq_len of the last tracking step
mamba_branching_seqlen: Optional[int]    # branching-point seq_len
```
6.3 Tracking Computation Flow
6.3.1 The Prepare Phase
```python
# schedule_batch.py:1612-1664
def _mamba_radix_cache_v2_req_prepare_for_extend(self, req, ...):
    # decide whether tracking is needed
    mask = (req.extend_input_len // FLA_CHUNK_SIZE) * FLA_CHUNK_SIZE > 0
    mamba_track_mask_cpu.append(mask)
    if mask:
        # full tracking length
        mamba_track_seqlen = len(req.prefix_indices) + req.extend_input_len
        # aligned tracking length
        mamba_track_seqlen_aligned = (
            len(req.prefix_indices)
            + (req.extend_input_len // FLA_CHUNK_SIZE) * FLA_CHUNK_SIZE
        )
        # handle the branching point
        if req.mamba_branching_seqlen is not None:
            branching_seqlen_aligned_mask = (
                req.mamba_branching_seqlen - len(req.prefix_indices)
            ) % FLA_CHUNK_SIZE == 0
            if (req.mamba_branching_seqlen > len(req.prefix_indices)
                    and req.mamba_branching_seqlen < mamba_track_seqlen
                    and branching_seqlen_aligned_mask):
                # the branching point must be tracked
                mamba_track_seqlen = req.mamba_branching_seqlen + 1
                mamba_track_seqlen_aligned = req.mamba_branching_seqlen
        req.mamba_last_track_seqlen = mamba_track_seqlen_aligned
```
6.3.2 Conv State Tracking Index Computation
```python
# hybrid_linear_attn_backend.py:259-298
def _init_track_conv_indices(self, query_start_loc, forward_batch):
    """
    Compute the positions of the conv states that need to be tracked.
    The conv state stores the last conv_kernel_size inputs, so those
    positions must be gathered from the input sequence.
    """
    conv_state_len = self.conv_states_shape[-1]  # conv_kernel_size

    # aligned length of the newly computed part
    lens_to_track = forward_batch.mamba_track_seqlens - forward_batch.extend_prefix_lens
    aligned_len = (lens_to_track // FLA_CHUNK_SIZE) * FLA_CHUNK_SIZE

    # start position = aligned length - conv_kernel_size
    start_indices = query_start_loc[:-1] + aligned_len - conv_state_len
    start_indices = start_indices[forward_batch.mamba_track_mask]

    # build indices of shape [num_tracked, conv_kernel_size]
    indices = start_indices.unsqueeze(-1) + torch.arange(conv_state_len, ...)
    return indices.clamp(0, query_start_loc[-1] - 1)
```
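Before the worked example, a runnable miniature (ours) of the index construction for a single request:

```python
import torch

conv_kernel_size = 4
query_start = torch.tensor([0])    # query_start_loc[:-1] for a one-request batch
aligned_len = torch.tensor([192])  # (lens_to_track // 64) * 64

start = query_start + aligned_len - conv_kernel_size          # tensor([188])
indices = start.unsqueeze(-1) + torch.arange(conv_kernel_size)
print(indices)  # tensor([[188, 189, 190, 191]])
```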
Example:
```
extend_prefix_lens = 100   (already cached)
mamba_track_seqlens = 300  (track up to 300)
extend_input_len = 210     (210 new input tokens)

lens_to_track = 300 - 100 = 200
aligned_len = (200 // 64) * 64 = 192

With conv_kernel_size = 4:
start_indices = query_start_loc[:-1] + 192 - 4 = query_start_loc[:-1] + 188
Extracted positions: [188, 189, 190, 191] (the last 4 inputs)
```
6.3.3 SSM State Tracking Index Computation
```python
# hybrid_linear_attn_backend.py:300-378
def _init_track_ssm_indices(self, mamba_cache_indices, forward_batch):
    """
    Compute the tracking indices for the SSM states.

    The FLA kernel outputs:
    - h: [NT, H, K, V] - the state at every chunk boundary
    - last_recurrent_state: [N, H, K, V] - the state at the end of the sequence

    We must decide whether to read each state from h or from last_recurrent_state.
    """
    # number of chunks per request
    num_h_states = (extend_seq_lens - 1) // FLA_CHUNK_SIZE + 1

    # offsets into the flattened h tensor
    track_ssm_src_offset = torch.zeros_like(num_h_states)
    track_ssm_src_offset[1:] = torch.cumsum(num_h_states[:-1], dim=0)

    # filter down to the requests that need tracking
    lens_to_track = mamba_track_seqlens - prefix_lens
    lens_masked = lens_to_track[mamba_track_mask]
    offset_masked = track_ssm_src_offset[mamba_track_mask]
    dst_masked = mamba_track_indices[mamba_track_mask]

    # aligned or not?
    is_aligned = (lens_masked % FLA_CHUNK_SIZE) == 0

    # Case 1: aligned - read from last_recurrent_state
    track_ssm_final_src = mamba_cache_indices[mamba_track_mask][is_aligned]
    track_ssm_final_dst = dst_masked[is_aligned]

    # Case 2: not aligned - read from h (an intermediate state)
    not_aligned = ~is_aligned
    track_ssm_h_src = offset_masked[not_aligned] + (
        lens_masked[not_aligned] // FLA_CHUNK_SIZE
    )
    track_ssm_h_dst = dst_masked[not_aligned]

    return track_ssm_h_src, track_ssm_h_dst, track_ssm_final_src, track_ssm_final_dst
```
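A small torch sketch (ours) of the aligned/unaligned split, ahead of the worked example below:

```python
import torch

FLA_CHUNK_SIZE = 64
lens_masked = torch.tensor([128, 110])  # tracked lengths for two requests
is_aligned = (lens_masked % FLA_CHUNK_SIZE) == 0

print(is_aligned)                                  # tensor([ True, False])
# The aligned request (128) reads last_recurrent_state; the unaligned one
# (110) reads the intermediate state h at index lens // FLA_CHUNK_SIZE
# within its own request, plus the per-request offset into the flat h tensor.
print(lens_masked[~is_aligned] // FLA_CHUNK_SIZE)  # tensor([1])
```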
Example:
```
extend_seq_lens = 210
prefix_lens = 100
lens_to_track = 110

num_h_states = (210 - 1) // 64 + 1 = 4

chunk layout:
- h[0]: covers positions [0-63]
- h[1]: covers positions [64-127]
- h[2]: covers positions [128-191]
- last_recurrent_state: covers positions [0-210]

lens_masked = 110
is_aligned = (110 % 64 == 0) = False

so the state must come from h:
110 // 64 = 1
read the state from h[offset + 1]
```
7. Copy-on-Write Implementation
7.1 Why Is Copy-on-Write Needed?
When multiple requests share the same prefix and one of them needs to keep computing, the mamba cache must be copied:
```
Request A: "Hello world"                [cached]
Request B: "Hello world, how are you?"  [reuses A's cache]

Problem:
- A and B share the same mamba_value
- if B keeps decoding, it mutates the cache
- A's cache would be corrupted!

Solution: Copy-on-Write
- B forks an independent copy of the cache
- the two no longer interfere
```
7.2 The CoW Implementation
```python
# mamba_radix_cache.py:440-459
def match_prefix(self, key, cow_mamba=False, req=None):
    value, last_node, mamba_branching_seqlen = self._match_prefix_helper(key)

    # Copy-on-Write for Mamba
    if cow_mamba and last_node.mamba_value is not None:
        if req.mamba_pool_idx is None:
            # allocate a fresh cache slot
            dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
            if dst_index is None:
                # out of space: evict and retry
                self.inc_lock_ref(last_node)
                self.evict_mamba(1)
                dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
                self.dec_lock_ref(last_node)
            # copy the state
            src_index = last_node.mamba_value
            self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)
            req.mamba_pool_idx = dst_index
        else:
            # the request already owns a slot: refresh it
            src_index = last_node.mamba_value
            dst_index = req.mamba_pool_idx.unsqueeze(0)
            self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)

    return MatchResult(device_indices=torch.cat(value), ...)
```
7.3 The Fork Implementation
```python
# memory_pool.py:304-309
def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]:
    """Fork a new copy of the mamba cache."""
    dst_index = self.alloc(1)
    if dst_index is None:
        return None
    self.copy_from(src_index, dst_index)
    return dst_index

def copy_from(self, src_index, dst_index):
    """Copy both the conv and the temporal state."""
    for i in range(len(self.mamba_cache.conv)):
        self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][:, src_index]
    self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[:, src_index]
```
7.4 Usage
```python
# mamba_radix_cache.py:576-689
def cache_unfinished_req(self, req: Req):
    """Cache an unfinished request (chunked prefill)."""
    # fork the mamba cache
    mamba_value = self.req_to_token_pool.get_mamba_indices(req.req_pool_idx)
    mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(mamba_value)

    # if the fork failed, evict and retry
    if mamba_value_forked is None:
        self.evict_mamba(1)
        mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(mamba_value)

    # insert a new tree node
    new_prefix_len, mamba_exist = self.insert(
        RadixKey(page_aligned_token_ids, req.extra_key),
        page_aligned_kv_indices,
        mamba_value_forked,
    )

    # if a mamba cache already existed there, free the forked copy
    if mamba_exist:
        self.req_to_token_pool.mamba_pool.free(mamba_value_forked)
```
8. Cache Save Flow
8.1 Cache Finished Req
```python
# mamba_radix_cache.py:480-572
def cache_finished_req(self, req: Req, is_insert: bool = True):
    # length of KV already committed
    kv_committed_len = req.pop_committed_kv_cache()
    token_ids = (req.origin_input_ids + req.output_ids)[:kv_committed_len]
    kv_indices = self.req_to_token_pool.req_to_token[
        req.req_pool_idx, :kv_committed_len
    ]

    if is_insert:
        # actual length to cache
        cache_len = (
            req.mamba_last_track_seqlen
            if self.enable_mamba_extra_buffer
            else len(token_ids)
        )

        # align to page_size
        if self.page_size != 1:
            page_aligned_len = len(kv_indices) // self.page_size * self.page_size
            page_aligned_kv_indices = kv_indices[:page_aligned_len]
            page_aligned_token_ids = token_ids[:page_aligned_len]

        # take the mamba cache from the ping-pong buffer
        if self.enable_mamba_extra_buffer:
            mamba_ping_pong_track_buffer_to_keep = (
                self.req_to_token_pool.get_mamba_ping_pong_other_idx(
                    req.mamba_ping_pong_track_buffer
                )
            )
            mamba_value = (
                req.mamba_ping_pong_track_buffer[
                    mamba_ping_pong_track_buffer_to_keep
                ].unsqueeze(-1).clone()
            )
        else:
            mamba_value = req.mamba_pool_idx.unsqueeze(-1).clone()

        # insert into the radix tree
        new_prefix_len, mamba_exist = self.insert(
            RadixKey(page_aligned_token_ids, req.extra_key),
            page_aligned_kv_indices,
            mamba_value,
        )

        # free the KV cache now covered by the tree
        self.token_to_kv_pool_allocator.free(
            kv_indices[req.cache_protected_len : new_prefix_len]
        )

        # free the mamba cache if one already existed there
        if mamba_exist:
            mamba_ping_pong_track_buffer_to_keep = None

    # free the request slot
    self.req_to_token_pool.free(
        req.req_pool_idx,
        free_mamba_cache=free_mamba_cache,
        mamba_ping_pong_track_buffer_to_keep=mamba_ping_pong_track_buffer_to_keep,
    )
```
8.2 State Tracking During Extend
```python
# hybrid_linear_attn_backend.py:585-617
def _track_mamba_state_extend(self, forward_batch, h, ssm_states, forward_metadata):
    """Track SSM states into their cache slots during extend."""
    if (forward_batch.mamba_track_mask is not None
            and forward_batch.mamba_track_mask.any()):
        h = h.squeeze(0)

        # Case 1: read from h (not aligned)
        if forward_metadata.track_ssm_h_src.numel() > 0:
            ssm_states[forward_metadata.track_ssm_h_dst] = h[
                forward_metadata.track_ssm_h_src
            ].to(ssm_states.dtype, copy=False)

        # Case 2: read from last_recurrent_state (aligned)
        if forward_metadata.track_ssm_final_src.numel() > 0:
            ssm_states[forward_metadata.track_ssm_final_dst] = ssm_states[
                forward_metadata.track_ssm_final_src
            ]
```
Conv state tracking:
```python
# hybrid_linear_attn_backend.py:997-1023
# inside forward_extend:
# gather the conv states that need to be tracked
if (forward_batch.mamba_track_mask is not None
        and forward_batch.mamba_track_mask.any()):
    conv_dst = forward_batch.mamba_track_indices

    # extract the conv window from the input sequence
    mixed_qkv_to_track = mixed_qkv[
        :, forward_metadata.track_conv_indices
    ].transpose(0, 1)

    # write into the cache
    mask_indices = forward_batch.mamba_track_mask.nonzero(as_tuple=True)[0]
    conv_states[conv_dst[mask_indices]] = mixed_qkv_to_track
```
8.3 State Tracking During Decode
```python
# hybrid_linear_attn_backend.py:555-583
def _track_mamba_state_decode(self, forward_batch, conv_states, ssm_states, cache_indices):
    """During decode, copy the updated state into its persistent cache slot."""
    if forward_batch.mamba_track_mask is not None:
        track_mamba_states_if_needed(
            conv_states,                        # source: working slot
            ssm_states,                         # source: working slot
            cache_indices,                      # source indices
            forward_batch.mamba_track_mask,
            forward_batch.mamba_track_indices,  # destination indices
            forward_batch.batch_size,
        )
```
The Triton kernel:
```python
# hybrid_linear_attn_backend.py:64-122
@triton.jit
def track_mamba_state_if_needed_kernel(
    conv_states_ptr,
    ssm_states_ptr,
    cache_indices_ptr,       # source indices
    mamba_track_mask_ptr,    # whether to track
    mamba_track_indices_ptr, # destination indices
    ...
):
    batch_idx = tl.program_id(0)

    # check the mask
    track_mask = tl.load(mamba_track_mask_ptr + batch_idx)
    if not track_mask:
        return

    # load the indices
    src_idx = tl.load(cache_indices_ptr + batch_idx)
    dst_idx = tl.load(mamba_track_indices_ptr + batch_idx)

    # copy the conv states
    for offset in range(0, conv_state_numel_per_row, BLOCK_SIZE):
        src_ptr = conv_states_ptr + src_idx * conv_state_stride_0 + offset
        dst_ptr = conv_states_ptr + dst_idx * conv_state_stride_0 + offset
        data = tl.load(src_ptr, mask=mask, other=0.0)
        tl.store(dst_ptr, data, mask=mask)

    # copy the SSM states
    for offset in range(0, ssm_state_numel_per_row, BLOCK_SIZE):
        src_ptr = ssm_states_ptr + src_idx * ssm_state_stride_0 + offset
        dst_ptr = ssm_states_ptr + dst_idx * ssm_state_stride_0 + offset
        data = tl.load(src_ptr, mask=mask, other=0.0)
        tl.store(dst_ptr, data, mask=mask)
```
9. End-to-End Example
9.1 Scenario Setup
```
Initial state:
- cache: empty
- FLA_CHUNK_SIZE: 64

Request 1: "Hello world, this is a test" (8 tokens)
Handling: extend, no cache
Result:
- compute 0-7 (8 tokens)
- aligned length = 0 (8 < 64)
- no mamba cache is saved (less than one chunk)

Request 2: "Hello world, this is a test of the system" (12 tokens)
Handling: extend, partial cache
Match:
- KV cache: matches "Hello world, this is a test" (8 tokens)
- mamba cache: matches 0 tokens (aligned down to 0)
Compute:
- newly compute 8-11 (4 tokens)
- aligned length = 0 (12 < 64)
- no mamba cache is saved

Request 3: a very long input (150 tokens),
           with the first 100 tokens assumed already processed
Handling: extend (the remaining 50 tokens)
Match:
- KV cache: matches the first 100
- mamba cache:
  - find the longest matching path
  - suppose the cache holds a state covering 96 tokens
  - mamba_branching_seqlen = 96 (the branching point)
Compute:
- reuse KV: 0-99
- reuse mamba: 0-95 (from the cached mamba state)
- newly compute: 100-149 (50 tokens)
- aligned length = (150 // 64) * 64 = 128
Save:
- KV cache: the first 128 tokens (aligned to page_size=1)
- mamba cache: the state at position 128
  - conv_state: the last 4 inputs [124-127]
  - ssm_state: the chunk-boundary state (128)

Request 4: continue decoding (1 token)
Handling: decode
State:
- mamba_pool_idx: already allocated (from request 3's cache)
- conv_state: the conv window at position 128
- ssm_state: the state at position 128
Compute:
- feed in 1 token
- conv_state: updated (sliding window)
- ssm_state: updated (recurrence)
Tracking:
- mamba_track_mask: True (tracked once every mamba_track_interval tokens)
- the state is copied into the tracking slot
```
9.2 Tree Structure Evolution
```
Initially:
Root

After request 3 is processed:
Root
└─ [0-127] (value=[0-127], mamba_value=[slot_10])

After request 3's cache_finished_req:
(only the first 128 tokens are actually saved)
Root
└─ [0-127] (value=[0-127], mamba_value=[slot_10])

After request 4 decodes N times:
(assuming mamba_track_interval = 64)
Root
├─ [0-127] (value=[0-127], mamba_value=[slot_10])
└─ [0-191] (value=[0-191], mamba_value=[slot_11])
   (the new full sequence, including the decoded tokens)
```
9.3 Tombstone Example
```
Scenario: two requests share a prefix, then diverge.

Request A: "Common prefix + branch A" (200 tokens)
Request B: "Common prefix + branch B" (200 tokens)

Step 1: after request A finishes
Root
└─ [0-199] (value=[...], mamba_value=[slot_20])

Step 2: request B uses CoW
- matches the prefix [0-199]
- forks mamba_value: slot_20 -> slot_21
- request B computes branch B independently

Step 3: request B's state is evicted (LRU)
Root
├─ [0-199] (value=[...], mamba_value=[slot_20])  # request A's path
└─ [0-199] (value=[...], mamba_value=None)       # request B's path: tombstone!

Step 4: new request C (150 tokens, partially matching the prefix)
Match:
- KV cache: matches [0-149]
- mamba cache:
  - walks to the [0-199] node
  - finds mamba_value=None (a tombstone)
  - falls back to Root
  - effective match: 0 tokens
```
10. Key Data Structures
10.1 ForwardBatch Fields
```python
@dataclass
class ForwardBatch:
    # mamba tracking
    mamba_track_indices: torch.Tensor  # [b] target cache slots
    mamba_track_mask: torch.Tensor     # [b] whether tracking is needed
    mamba_track_seqlens: torch.Tensor  # [b] seq_len tracked up to

    # extend-related
    extend_seq_lens: torch.Tensor      # [b] input length per request
    extend_prefix_lens: torch.Tensor   # [b] prefix length per request
    extend_seq_lens_cpu: List[int]     # [b] seq_lens on the CPU
```
10.2 ForwardMetadata Fields
```python
@dataclass(kw_only=True)
class ForwardMetadata:
    query_start_loc: torch.Tensor      # cumulative sequence lengths
    mamba_cache_indices: torch.Tensor  # cache slot indices

    # tracking-related
    track_conv_indices: Optional[torch.Tensor]   # conv state positions
    track_ssm_h_src: Optional[torch.Tensor]      # SSM h source indices
    track_ssm_h_dst: Optional[torch.Tensor]      # SSM h destination indices
    track_ssm_final_src: Optional[torch.Tensor]  # SSM final source indices
    track_ssm_final_dst: Optional[torch.Tensor]  # SSM final destination indices
```
10.3 Req Fields
```python
class Req:
    # mamba cache slot
    mamba_pool_idx: Optional[int]            # slot currently in use

    # ping-pong buffer
    mamba_ping_pong_track_buffer: Tensor     # [2] double buffer
    mamba_next_track_idx: int                # index of the next buffer

    # tracking state
    mamba_last_track_seqlen: Optional[int]   # seq_len of the last tracking step
    mamba_branching_seqlen: Optional[int]    # branching-point seq_len

    # prefix information
    prefix_indices: torch.Tensor             # matched token indices
    last_node: TreeNode                      # radix tree node
```
10.4 MambaPool Layout
```python
class MambaPool:
    @dataclass(frozen=True, kw_only=True)
    class State:
        conv: List[torch.Tensor]  # [num_layers, pool_size+1, dim]
        temporal: torch.Tensor    # [num_layers, pool_size+1, H, K, V]

    @dataclass(frozen=True, kw_only=True)
    class SpeculativeState(State):
        intermediate_ssm: torch.Tensor                # [num_layers, pool_size+1, draft, H, K, V]
        intermediate_conv_window: List[torch.Tensor]  # [num_layers, pool_size+1, draft, dim, K-1]
```
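A hedged free-list sketch (ours, not the SGLang MambaPool) of slot allocation over a layout like the one above; note how a failed alloc returns None so the caller can evict and retry, as in the CoW code earlier:

```python
from typing import List, Optional
import torch

class ToyMambaPool:
    def __init__(self, pool_size: int, num_layers: int = 2, dim: int = 16):
        # pool_size + 1 mirrors the layout above (one extra reserved slot).
        self.conv: List[torch.Tensor] = [
            torch.zeros(pool_size + 1, dim) for _ in range(num_layers)
        ]
        self.temporal = torch.zeros(num_layers, pool_size + 1, dim, dim)
        self.free_slots = list(range(1, pool_size + 1))

    def alloc(self, n: int) -> Optional[torch.Tensor]:
        if len(self.free_slots) < n:
            return None  # caller evicts (e.g. evict_mamba) and retries
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return torch.tensor(out)

    def free(self, idx: torch.Tensor) -> None:
        self.free_slots.extend(idx.tolist())
```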
Summary
Core mechanisms of the Linear Attention prefix cache:
- Chunk granularity: states can only be cached at FLA_CHUNK_SIZE (64) boundaries
- Alignment constraint: mamba_branching_seqlen = (len // 64) * 64
- Tombstone mechanism: lets the KV cache and the mamba cache be managed independently
- Copy-on-Write: state duplication when multiple requests share a prefix
- State tracking:
  - Conv state: the last K-1 inputs
  - SSM state: the recurrent state at chunk boundaries
- Double buffering: a ping-pong buffer alternates the tracked state (sketch below)
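A minimal sketch (names ours) of the ping-pong idea: two alternating slots, so the state being written in the current step never clobbers the last committed one:

```python
class PingPongTracker:
    def __init__(self, slot_a: int, slot_b: int):
        self.buffer = [slot_a, slot_b]  # analogue of mamba_ping_pong_track_buffer
        self.next_idx = 0               # analogue of mamba_next_track_idx

    def begin_track(self) -> int:
        """Slot to write the state computed in the current step."""
        return self.buffer[self.next_idx]

    def commit(self) -> None:
        """Write finished: flip so the next write targets the other slot."""
        self.next_idx = 1 - self.next_idx

    def last_committed(self) -> int:
        """Slot holding the most recently committed state (valid after a commit)."""
        return self.buffer[1 - self.next_idx]

t = PingPongTracker(10, 11)
dst = t.begin_track()  # write into slot 10
t.commit()
assert t.last_committed() == 10 and t.begin_track() == 11
```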
Differences from Standard Attention

| Property | Standard Attention | Linear Attention |
|---|---|---|
| Cache granularity | token | chunk (64 tokens) |
| State dependency | none (fully parallel) | recurrent (sequential) |
| Prefix matching | exact, any length | aligned to chunk boundaries |
| Cache structure | a single KV cache | KV + conv + SSM |
| Eviction policy | unified | separate (tombstones) |
File Path Index

| Component | File Path |
|---|---|
| Mamba Radix Cache | python/sglang/srt/mem_cache/mamba_radix_cache.py |
| Hybrid Linear Backend | python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py |
| Chunk Gated Delta Rule | python/sglang/srt/layers/attention/fla/chunk.py |
| Chunk Delta H Kernel | python/sglang/srt/layers/attention/fla/chunk_delta_h.py |
| Mamba Metadata | python/sglang/srt/layers/attention/mamba/mamba2_metadata.py |
| Schedule Batch | python/sglang/srt/managers/schedule_batch.py |
| Memory Pool | python/sglang/srt/mem_cache/memory_pool.py |
Key Code Location Index

- Tombstone creation: mamba_radix_cache.py:1099-1103
- Mamba cache eviction: mamba_radix_cache.py:729-759
- Prefix matching: mamba_radix_cache.py:900-967
- Branching-point computation: mamba_radix_cache.py:955-966
- Conv tracking indices: hybrid_linear_attn_backend.py:259-298
- SSM tracking indices: hybrid_linear_attn_backend.py:300-378
- State tracking: hybrid_linear_attn_backend.py:585-617, 555-583
- CoW implementation: mamba_radix_cache.py:440-459, memory_pool.py:304-309
- Ping-pong buffer: schedule_batch.py:1645-1663
- Chunk kernel: chunk_delta_h.py:33-272
This implementation embraces the recurrent nature of linear attention: through chunk-level caching and state tracking, it delivers prefix caching while preserving computational efficiency.