mamba-cache

Qwen3Next Linear Attention Cache 粒度问题详解

问题现象

用户使用默认参数时,发现 cache 的粒度似乎不是固定的 64 tokens,而是以 chunk prefill size 为边界。

根本原因

1. 默认策略是 “no_buffer”

查看 server_args.py:767-770:

1
2
3
4
5
# Mamba scheduler strategy
if self.mamba_scheduler_strategy == "auto":
    # TODO: when extra_buffer is more verified, we can set the default path based on
    #       [overlap, non-overlap]
    self.mamba_scheduler_strategy = "no_buffer"

默认配置:

1
2
mamba_scheduler_strategy: str = "auto"  # 默认值
# 实际运行时会被设置为: "no_buffer"

2. 两种策略的对比

特性 no_buffer (默认) extra_buffer
enable_mamba_extra_buffer() False True
mamba_ping_pong_track_buffer None [2] 双缓冲
cache_len len(token_ids) mamba_last_track_seqlen
缓存粒度 整个输入长度 Chunk 对齐 (64)
追踪机制 完整追踪

3. “no_buffer” 策略下的缓存行为

缓存保存

查看 mamba_radix_cache.py:589-593:

1
2
3
4
5
cache_len = (
    req.mamba_last_track_seqlen
    if self.enable_mamba_extra_buffer
    else len(token_ids)  # no_buffer: 缓存整个序列
)

在 no_buffer 模式下:

1
cache_len = len(token_ids)  # 例如:输入 210 tokens,就保存 210 tokens

缓存匹配

查看 mamba_radix_cache.py:955-963:

1
2
3
4
5
6
7
if len(value) > best_value_len:
    fla_chunk_aligned_seqlen = (
        sum(len(v) for v in value) // FLA_CHUNK_SIZE
    ) * FLA_CHUNK_SIZE  # 向下对齐到 64
    mamba_branching_seqlen = (
        fla_chunk_aligned_seqlen if fla_chunk_aligned_seqlen > 0 else None
    )

关键: 无论保存了多少 tokens,匹配时都会向下对齐到 64 的倍数

实际场景分析

场景 1: 首次请求 (210 tokens)

1
2
3
4
5
6
7
8
9
输入: 210 tokens

保存缓存 (no_buffer):
- KV cache: 210 tokens
- Mamba cache: 210 tokens (保存了所有状态)

树节点: [0-209]
  value: [0-209]        # KV cache 索引
  mamba_value: [slot_x] # Mamba cache 索引 (210 tokens 的状态)

场景 2: 第二个请求 (相同前缀 200 tokens)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
输入: 210 tokens
前缀匹配: 200 tokens

匹配过程:
1. 遍历树,找到节点 [0-209]
2. 计算匹配长度: 200 tokens
3. 对齐到 chunk 边界:
   fla_chunk_aligned_seqlen = (200 // 64) * 64 = 192

实际使用:
- KV cache: 200 tokens (完全复用)
- Mamba cache: 192 tokens (只能复用前 192 个)

新计算:
- Token 192-209: 需要重新计算
- Token 210-219: 需要重新计算

为什么 Mamba cache 只能复用 192?

因为 Linear Attention 的状态是在 chunk 边界输出的:

  • h[0]: 对应 token 0-63 的状态
  • h[1]: 对应 token 64-127 的状态
  • h[2]: 对应 token 128-191 的状态
  • last_recurrent_state: 对应 token 0-209 的状态

但要从位置 192 开始计算,需要 h[2] (token 191 的状态),而不是 last_recurrent_state。

场景 3: Decode 阶段

1
2
3
4
5
6
7
8
9
10
11
12
当前状态: 210 tokens
每步 decode: 1 token

在 no_buffer 模式下:
- 不使用 mamba_ping_pong_track_buffer
- 不追踪 mamba_last_track_seqlen
- 直接使用 req.mamba_pool_idx 的状态

每步 decode:
- 更新 conv_state (in-place)
- 更新 ssm_state (in-place)
- 不需要额外追踪

为什么看起来是 Chunk Prefill Size 粒度?

实际观察

当用户使用默认参数时:

1
2
3
# 默认配置
mamba_scheduler_strategy = "auto"  # → "no_buffer"
chunked_prefill_size = 2048 (或根据 GPU 设置)

实际缓存行为:

  1. Prefill 阶段 (例如输入 2048 tokens):
    • 保存: 2048 tokens 的完整状态
    • 树节点: [0-2047]
  2. 后续请求 (相同前缀):
    • KV cache: 精确匹配任意长度
    • Mamba cache: 向下对齐到 64 的倍数
      • 例如: 匹配 1500 tokens → 实际复用 1472 tokens (1500 // 64 * 64)
  3. 为什么看起来是 chunked_prefill_size 粒度?
    • 因为通常 prefill 的输入长度都是 chunked_prefill_size 的倍数
    • 例如: 2048, 4096, 8192…
    • 这些数字本身就能被 64 整除
    • 所以看不到对齐的损失

真实的对齐损失

1
2
3
4
5
6
7
8
# 输入长度不是 64 的倍数
输入: 2100 tokens
对齐后: 1952 tokens (2100 // 64 * 64)
损失: 148 tokens

输入: 210 tokens
对齐后: 192 tokens
损失: 18 tokens

解决方案

方案 1: 使用 extra_buffer 策略 (推荐)

1
2
3
4
5
# 启动时指定
python -m sglang.launch_server \
    --model <model_path> \
    --mamba-scheduler-strategy extra_buffer \
    ...

效果:

  • 启用 ping-pong buffer
  • 以 FLA_CHUNK_SIZE (64) 为粒度保存缓存
  • 更精确的缓存复用
  • 更少的重复计算

代价:

  • 额外的内存开销 (ping-pong buffer)
  • 每隔 mamba_track_interval (默认 256) 进行状态追踪

方案 2: 调整输入长度对齐

在 extra_buffer 模式下,确保 prefill 输入长度是 64 的倍数:

1
2
3
4
5
6
# 例如: padding 到 64 的倍数
input_len = 2100
aligned_len = ((input_len + 63) // 64) * 64  # 2112

# 添加 padding tokens
padded_input = input_ids + [pad_token_id] * (aligned_len - input_len)

方案 3: 接受默认行为

在大多数场景下,“no_buffer” 策略已经足够好:

  • KV cache 仍然是 token 粒度的,可以精确匹配
  • Mamba cache 的对齐损失通常很小 (< 3%)
  • 不需要额外的内存开销

代码位置索引

功能 文件路径 行号
策略选择 server_args.py 767-770
enable_mamba_extra_buffer server_args.py 4538-4539
缓存长度计算 mamba_radix_cache.py 589-593
对齐计算 mamba_radix_cache.py 955-963
mamba_track_interval server_args.py 474
追踪更新 scheduler_output_processor_mixin.py 467-491

总结

  1. 默认策略是 “no_buffer”,不是 “extra_buffer”
  2. FLA_CHUNK_SIZE (64) 是硬编码的,无法通过参数修改
  3. 对齐损失发生在匹配时,而不是保存时
  4. 要获得精确的 chunk 粒度缓存,需要显式设置 --mamba-scheduler-strategy extra_buffer
  5. 在大多数情况下,默认的 “no_buffer” 策略已经提供了良好的性能平衡

关键参数

1
2
3
4
5
# 相关启动参数
--mamba-scheduler-strategy [auto|no_buffer|extra_buffer]  # 默认: auto (→ no_buffer)
--mamba-track-interval 256                               # extra_buffer 模式下的追踪间隔
--chunked-prefill-size 2048                               # Prefill chunk 大小
--max-mamba-cache-size 4096                               # Mamba cache 最大 slot 数

注意: FLA_CHUNK_SIZE (64) 是内核常量,无法通过启动参数修改。修改它需要重新编译 Triton kernel。