mamba-cache

mamba-cache
gogongxtQwen3Next Linear Attention Cache 粒度问题详解
问题现象
用户使用默认参数时,发现 cache 的粒度似乎不是固定的 64 tokens,而是以 chunk prefill size 为边界。
根本原因
1. 默认策略是 “no_buffer”
查看 server_args.py:767-770:
1
2
3
4
5
# Mamba scheduler strategy
if self.mamba_scheduler_strategy == "auto":
# TODO: when extra_buffer is more verified, we can set the default path based on
# [overlap, non-overlap]
self.mamba_scheduler_strategy = "no_buffer"默认配置:
1
2
mamba_scheduler_strategy: str = "auto" # 默认值
# 实际运行时会被设置为: "no_buffer"2. 两种策略的对比
| 特性 | no_buffer (默认) | extra_buffer |
|---|---|---|
| enable_mamba_extra_buffer() | False | True |
| mamba_ping_pong_track_buffer | None | [2] 双缓冲 |
| cache_len | len(token_ids) |
mamba_last_track_seqlen |
| 缓存粒度 | 整个输入长度 | Chunk 对齐 (64) |
| 追踪机制 | 无 | 完整追踪 |
3. “no_buffer” 策略下的缓存行为
缓存保存
查看 mamba_radix_cache.py:589-593:
1
2
3
4
5
cache_len = (
req.mamba_last_track_seqlen
if self.enable_mamba_extra_buffer
else len(token_ids) # no_buffer: 缓存整个序列
)在 no_buffer 模式下:
1
cache_len = len(token_ids) # 例如:输入 210 tokens,就保存 210 tokens缓存匹配
查看 mamba_radix_cache.py:955-963:
1
2
3
4
5
6
7
if len(value) > best_value_len:
fla_chunk_aligned_seqlen = (
sum(len(v) for v in value) // FLA_CHUNK_SIZE
) * FLA_CHUNK_SIZE # 向下对齐到 64
mamba_branching_seqlen = (
fla_chunk_aligned_seqlen if fla_chunk_aligned_seqlen > 0 else None
)关键: 无论保存了多少 tokens,匹配时都会向下对齐到 64 的倍数!
实际场景分析
场景 1: 首次请求 (210 tokens)
1
2
3
4
5
6
7
8
9
输入: 210 tokens
保存缓存 (no_buffer):
- KV cache: 210 tokens
- Mamba cache: 210 tokens (保存了所有状态)
树节点: [0-209]
value: [0-209] # KV cache 索引
mamba_value: [slot_x] # Mamba cache 索引 (210 tokens 的状态)场景 2: 第二个请求 (相同前缀 200 tokens)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
输入: 210 tokens
前缀匹配: 200 tokens
匹配过程:
1. 遍历树,找到节点 [0-209]
2. 计算匹配长度: 200 tokens
3. 对齐到 chunk 边界:
fla_chunk_aligned_seqlen = (200 // 64) * 64 = 192
实际使用:
- KV cache: 200 tokens (完全复用)
- Mamba cache: 192 tokens (只能复用前 192 个)
新计算:
- Token 192-209: 需要重新计算
- Token 210-219: 需要重新计算为什么 Mamba cache 只能复用 192?
因为 Linear Attention 的状态是在 chunk 边界输出的:
- h[0]: 对应 token 0-63 的状态
- h[1]: 对应 token 64-127 的状态
- h[2]: 对应 token 128-191 的状态
- last_recurrent_state: 对应 token 0-209 的状态
但要从位置 192 开始计算,需要 h[2] (token 191 的状态),而不是 last_recurrent_state。
场景 3: Decode 阶段
1
2
3
4
5
6
7
8
9
10
11
12
当前状态: 210 tokens
每步 decode: 1 token
在 no_buffer 模式下:
- 不使用 mamba_ping_pong_track_buffer
- 不追踪 mamba_last_track_seqlen
- 直接使用 req.mamba_pool_idx 的状态
每步 decode:
- 更新 conv_state (in-place)
- 更新 ssm_state (in-place)
- 不需要额外追踪为什么看起来是 Chunk Prefill Size 粒度?
实际观察
当用户使用默认参数时:
1
2
3
# 默认配置
mamba_scheduler_strategy = "auto" # → "no_buffer"
chunked_prefill_size = 2048 (或根据 GPU 设置)实际缓存行为:
- Prefill 阶段 (例如输入 2048 tokens):
- 保存: 2048 tokens 的完整状态
- 树节点: [0-2047]
- 后续请求 (相同前缀):
- KV cache: 精确匹配任意长度
- Mamba cache: 向下对齐到 64 的倍数
- 例如: 匹配 1500 tokens → 实际复用 1472 tokens (1500 // 64 * 64)
- 为什么看起来是 chunked_prefill_size 粒度?
- 因为通常 prefill 的输入长度都是 chunked_prefill_size 的倍数
- 例如: 2048, 4096, 8192…
- 这些数字本身就能被 64 整除
- 所以看不到对齐的损失
真实的对齐损失
1
2
3
4
5
6
7
8
# 输入长度不是 64 的倍数
输入: 2100 tokens
对齐后: 1952 tokens (2100 // 64 * 64)
损失: 148 tokens
输入: 210 tokens
对齐后: 192 tokens
损失: 18 tokens解决方案
方案 1: 使用 extra_buffer 策略 (推荐)
1
2
3
4
5
# 启动时指定
python -m sglang.launch_server \
--model <model_path> \
--mamba-scheduler-strategy extra_buffer \
...效果:
- 启用 ping-pong buffer
- 以 FLA_CHUNK_SIZE (64) 为粒度保存缓存
- 更精确的缓存复用
- 更少的重复计算
代价:
- 额外的内存开销 (ping-pong buffer)
- 每隔 mamba_track_interval (默认 256) 进行状态追踪
方案 2: 调整输入长度对齐
在 extra_buffer 模式下,确保 prefill 输入长度是 64 的倍数:
1
2
3
4
5
6
# 例如: padding 到 64 的倍数
input_len = 2100
aligned_len = ((input_len + 63) // 64) * 64 # 2112
# 添加 padding tokens
padded_input = input_ids + [pad_token_id] * (aligned_len - input_len)方案 3: 接受默认行为
在大多数场景下,“no_buffer” 策略已经足够好:
- KV cache 仍然是 token 粒度的,可以精确匹配
- Mamba cache 的对齐损失通常很小 (< 3%)
- 不需要额外的内存开销
代码位置索引
| 功能 | 文件路径 | 行号 |
|---|---|---|
| 策略选择 | server_args.py |
767-770 |
| enable_mamba_extra_buffer | server_args.py |
4538-4539 |
| 缓存长度计算 | mamba_radix_cache.py |
589-593 |
| 对齐计算 | mamba_radix_cache.py |
955-963 |
| mamba_track_interval | server_args.py |
474 |
| 追踪更新 | scheduler_output_processor_mixin.py |
467-491 |
总结
- 默认策略是 “no_buffer”,不是 “extra_buffer”
- FLA_CHUNK_SIZE (64) 是硬编码的,无法通过参数修改
- 对齐损失发生在匹配时,而不是保存时
- 要获得精确的 chunk 粒度缓存,需要显式设置
--mamba-scheduler-strategy extra_buffer - 在大多数情况下,默认的 “no_buffer” 策略已经提供了良好的性能平衡
关键参数
1
2
3
4
5
# 相关启动参数
--mamba-scheduler-strategy [auto|no_buffer|extra_buffer] # 默认: auto (→ no_buffer)
--mamba-track-interval 256 # extra_buffer 模式下的追踪间隔
--chunked-prefill-size 2048 # Prefill chunk 大小
--max-mamba-cache-size 4096 # Mamba cache 最大 slot 数注意: FLA_CHUNK_SIZE (64) 是内核常量,无法通过启动参数修改。修改它需要重新编译 Triton kernel。
评论
匿名评论隐私政策




