qwennext-forward

qwennext-forward
gogongxt初始化
1
2
3
4
5
6
7
8
9
10
11
12
self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
size, dtype=torch.int32, device=self.device
)
if enable_mamba_extra_buffer
self.req_index_to_mamba_ping_pong_track_buffer_mapping: torch.Tensor = (
torch.zeros(
(size, self.mamba_ping_pong_track_buffer_size), # size是max running request self.mamba_ping_pong_track_buffer_size默认是256
dtype=torch.int32,
device=self.device,
)
)Qwen3Next Prefill Cache 可视化指南
完整流程数据流图
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
┌─────────────────────────────────────────────────────────────────────────────┐
│ 请求进入 Prefill 阶段 │
│ 输入: 210 tokens, 无缓存 │
└─────────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────────┐
│ Step 1: prepare_for_extend (schedule_batch.py:1612-1664) │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 1.1 匹配 Prefix Cache │ │
│ │ match_prefix(RadixKey(token_ids)) │ │
│ │ → prefix_indices: [] (空) │ │
│ │ → mamba_branching_seqlen: None │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 1.2 计算追踪信息 │ │
│ │ extend_input_len = 210 │ │
│ │ mask = (210 // 64) * 64 > 0 → 192 > 0 → True │ │
│ │ mamba_track_seqlen = 210 │ │
│ │ mamba_track_seqlen_aligned = 192 (向下对齐到64) │ │
│ │ mamba_track_indices = [slot_100] (分配的track slot) │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 1.3 构建输入 Tensor │ │
│ │ input_ids = [token_0, ..., token_209] (210 tokens) │ │
│ │ query_start_loc = [0, 210] │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────────┐
│ Step 2: init_forward_metadata (hybrid_linear_attn_backend.py:259-378) │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 2.1 计算Conv追踪索引 (_init_track_conv_indices) │ │
│ │ lens_to_track = 192 │ │
│ │ aligned_len = 192 │ │
│ │ conv_state_len = 4 (假设 conv_kernel=4) │ │
│ │ start_indices = 0 + 192 - 4 = 188 │ │
│ │ track_conv_indices = [188, 189, 190, 191] │ │
│ │ → 需要提取位置 188-191 的输入作为 conv state │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 2.2 计算SSM追踪索引 (_init_track_ssm_indices) │ │
│ │ extend_seq_lens = 210 │ │
│ │ num_h_states = (210-1)//64 + 1 = 4 │ │
│ │ lens_to_track = 192 │ │
│ │ is_aligned = (192 % 64 == 0) = True │ │
│ │ → 使用 last_recurrent_state (对齐) │ │
│ │ track_ssm_final_src = [working_slot] │ │
│ │ track_ssm_final_dst = [slot_100] │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────────┐
│ Step 3: forward_extend - Convolution (hybrid_linear_attn_backend.py:997-1023)│
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 3.1 执行 Causal Conv1d │ │
│ │ mixed_qkv.shape = [210, dim] │ │
│ │ conv_states.shape = [pool_size, dim, 4] │ │
│ │ cache_indices = [working_slot] │ │
│ │ has_initial_state = False (无prefix) │ │
│ │ │ │
│ │ mixed_qkv_conv = causal_conv1d_fn( │ │
│ │ mixed_qkv, │ │
│ │ conv_weights, │ │
│ │ conv_states, │ │
│ │ has_initial_state=False, │ │
│ │ cache_indices=[working_slot] │ │
│ │ ) │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 3.2 追踪 Conv State 到 Track Slot │ │
│ │ # 提取位置 188-191 的输入 │ │
│ │ mixed_qkv_to_track = mixed_qkv_conv[:, [188,189,190,191]] │ │
│ │ │ │
│ │ # 写入 track slot │ │
│ │ conv_states[slot_100] = mixed_qkv_to_track │ │
│ │ │ │
│ │ conv_states[slot_100] 现在包含 [token_188, token_189, │ │
│ │ token_190, token_191] │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────────┐
│ Step 4: forward_extend - Linear Attention (hybrid_linear_attn_backend.py:1068│
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 4.1 FLA Chunk 计算 │ │
│ │ chunk_gated_delta_rule( │ │
│ │ q, k, v, g, beta, │ │
│ │ initial_state=ssm_states, │ │
│ │ initial_state_indices=[working_slot] │ │
│ │ ) │ │
│ │ │ │
│ │ FLA Kernel 内部: │ │
│ │ ┌────────────────────────────────────────────────────────────┐ │ │
│ │ │ 序列分块 (210 tokens): │ │ │
│ │ │ │ │ │
│ │ │ Chunk 0: [0-63] 64 tokens │ │ │
│ │ │ Chunk 1: [64-127] 64 tokens │ │ │
│ │ │ Chunk 2: [128-191] 64 tokens │ │ │
│ │ │ Chunk 3: [192-209] 18 tokens (不完整) │ │ │
│ │ │ │ │ │
│ │ │ 状态输出: │ │ │
│ │ │ h[0]: Chunk 0 后的状态 (对应位置 63) │ │ │
│ │ │ h[1]: Chunk 1 后的状态 (对应位置 127) │ │ │
│ │ │ h[2]: Chunk 2 后的状态 (对应位置 191) │ │ │
│ │ │ h[3]: Chunk 3 后的状态 (对应位置 209,不完整) │ │ │
│ │ │ │ │ │
│ │ │ last_recurrent_state: 位置 209 的完整状态 │ │ │
│ │ │ (in-place 更新到 ssm_states[working_slot])│ │ │
│ │ └────────────────────────────────────────────────────────────┘ │ │
│ │ │ │
│ │ 返回: │ │
│ │ h.shape = [1, 4, H, K, V] (4个chunks) │ │
│ │ last_recurrent_state.shape = [1, H, K, V] │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 4.2 追踪 SSM State 到 Track Slot │ │
│ │ # 因为 is_aligned = True,从 last_recurrent_state 获取 │ │
│ │ ssm_states[slot_100] = ssm_states[working_slot] │ │
│ │ │ │
│ │ ssm_states[slot_100] 现在包含位置 192 的状态 (对齐到chunk边界) │ │
│ │ 注意: 不是 209! │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────────┐
│ Step 5: cache_finished_req (mamba_radix_cache.py:480-572) │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 5.1 计算缓存长度 │ │
│ │ kv_committed_len = 210 │ │
│ │ token_ids = [token_0, ..., token_209] │ │
│ │ kv_indices = [kv_slot_0, ..., kv_slot_209] │ │
│ │ │ │
│ │ # no_buffer 策略 (默认) │ │
│ │ cache_len = len(token_ids) = 210 │ │
│ │ │ │
│ │ # 但实际mamba cache只到192 │
│ │ mamba_last_track_seqlen = 192 │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 5.2 Fork Mamba Cache │ │
│ │ # 从 working_slot fork 到持久 cache slot │ │
│ │ mamba_value = [working_slot] │ │
│ │ mamba_value_forked = mamba_pool.fork_from([working_slot]) │ │
│ │ │ │
│ │ # 分配 slot_200,复制状态 │ │
│ │ conv_states[slot_200] = conv_states[working_slot] │ │
│ │ = [token_188, 189, 190, 191] │ │
│ │ ssm_states[slot_200] = ssm_states[working_slot] │ │
│ │ = 状态@位置192 │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 5.3 插入 Radix Tree │ │
│ │ insert( │ │
│ │ RadixKey([token_0, ..., token_209]), │ │
│ │ [kv_slot_0, ..., kv_slot_209], │ │
│ │ [slot_200] │ │
│ │ ) │ │
│ │ │ │
│ │ Radix Tree: │ │
│ │ Root │ │
│ │ └─ [0-209] │ │
│ │ value: [kv_slot_0, ..., kv_slot_209] (210 tokens, KV cache) │ │
│ │ mamba_value: [slot_200] (只包含到位置192的状态!) │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 5.4 释放资源 │ │
│ │ # 释放 working_slot 的 mamba cache (no_buffer 策略) │ │
│ │ free_mamba_cache = False │ │
│ │ req.mamba_pool_idx = None │ │
│ │ │ │
│ │ # KV cache 已在 token_to_kv_pool 中 │ │
│ │ # Mamba cache (slot_200) 在 radix tree 中 │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────────┐
│ 最终结果 │
│ │
│ Radix Tree Cache: │
│ ┌────────────────────────────────────────────────────────────────────┐ │
│ │ Root │ │
│ │ └─ [0-209] │ │
│ │ │ │
│ │ KV Cache (token粒度): │ │
│ │ value: [kv_slot_0, ..., kv_slot_209] (210 tokens) │ │
│ │ → 可以精确匹配任意长度前缀 │ │
│ │ │ │
│ │ Mamba Cache (chunk粒度): │ │
│ │ mamba_value: [slot_200] │ │
│ │ │ │
│ │ slot_200.conv: [token_188, 189, 190, 191] (conv state) │ │
│ │ slot_200.ssm: 状态@位置192 (ssm state) │ │
│ │ │ │
│ │ → 只能缓存到位置192 (192 = 210//64*64) │ │
│ │ → 位置192-209的18个token无法被mamba cache复用 │ │
│ └────────────────────────────────────────────────────────────────────┘ │
│ │
│ 对齐损失: 210 - 192 = 18 tokens (约8.6%) │
└─────────────────────────────────────────────────────────────────────────────┘第二次请求 (相同前缀 + 额外token)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
┌─────────────────────────────────────────────────────────────────────────────┐
│ 请求2: 230 tokens (相同前缀) │
│ 前210个token与请求1相同,后20个token不同 │
└─────────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────────┐
│ Step 1: prepare_for_extend - 匹配 Prefix Cache │
│ │
│ match_prefix(RadixKey([token_0, ..., token_229])) │
│ │
│ 匹配过程: │
│ 1. 遍历 Radix Tree,找到节点 [0-209] │
│ 2. 检查 mamba_value: 存在 ([slot_200]) │
│ 3. 计算匹配长度: 210 tokens │
│ 4. 计算 mamba_branching_seqlen: │
│ │
│ 匹配路径长度 = 210 │
│ 最后一个有 mamba_value 的节点 = [0-209] │
│ best_value_len = 210 │
│ │
│ len(value) = best_value_len → 不需要向下对齐 │
│ mamba_branching_seqlen = None │
│ │
│ 结果: │
│ prefix_indices: [kv_slot_0, ..., kv_slot_209] (210 tokens) │
│ last_node: [0-209] 节点 │
│ mamba_branching_seqlen: None │
└─────────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────────┐
│ Step 2: Copy-on-Write (fork mamba cache) │
│ │
│ cow_mamba = True, req.mamba_pool_idx = None │
│ │
│ 1. 分配新的 working slot: slot_300 │
│ 2. Fork mamba cache: │
│ conv_states[slot_300] = conv_states[slot_200] │
│ ssm_states[slot_300] = ssm_states[slot_200] │
│ 3. req.mamba_pool_idx = slot_300 │
│ │
│ 现在请求有独立的 mamba cache 副本,不会破坏 radix tree 中的cache │
└─────────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────────┐
│ Step 3: prepare_for_extend - 计算追踪信息 │
│ │
│ prefix_indices = [kv_slot_0, ..., kv_slot_209] (210 tokens) │
│ extend_input_len = 20 │
│ │
│ 计算追踪: │
│ mask = (20 // 64) * 64 > 0 → 0 > 0 → False │
│ mamba_track_mask = [False] # 不追踪 (长度不足64) │
│ │
│ 结果: 只处理新tokens,不保存新的mamba cache │
└─────────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────────┐
│ Step 4: forward_extend - 使用已有缓存 │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 4.1 Convolution (extend新tokens) │ │
│ │ mixed_qkv = [token_210, ..., token_229] (20 tokens) │ │
│ │ cache_indices = [slot_300] (fork的slot) │ │
│ │ has_initial_state = True # 有prefix! │ │
│ │ │ │
│ │ conv_states[slot_300]: [token_188, 189, 190, 191] │ │
│ │ │ │
│ │ causal_conv1d_fn( │ │
│ │ mixed_qkv, │ │
│ │ ..., │ │
│ │ has_initial_state=True, # 从slot_300的conv state开始 │ │
│ │ cache_indices=[slot_300] │ │
│ │ ) │ │
│ │ │ │
│ │ conv_states[slot_300] 被in-place更新 │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ 4.2 Linear Attention │ │
│ │ chunk_gated_delta_rule( │ │
│ │ q, k, v, g, beta, │ │
│ │ initial_state=ssm_states, │ │
│ │ initial_state_indices=[slot_300] # 从slot_300的ssm state开始 │ │
│ │ ) │ │
│ │ │ │
│ │ ssm_states[slot_300]: 状态@位置192 │ │
│ │ │ │
│ │ 从位置192继续计算到230 │ │
│ │ ssm_states[slot_300] 被in-place更新 │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────────┐
│ Step 5: cache_finished_req (不缓存,因为长度不足64) │
│ │
│ cache_len = 230 # no_buffer策略 │
│ │
│ 但 mamba_track_mask = False,所以实际不保存新的mamba cache │
│ 只复用了之前的 cache (slot_200) │
└─────────────────────────────────────────────────────────────────────────────┘关键数据结构变化
Conv State 变化
1
2
3
4
5
6
7
8
9
10
11
12
13
初始: conv_states[slot] = [0, 0, 0, 0] (conv_kernel=4)
Prefill 210 tokens:
[0, 1, 2, ..., 209]
Conv state 滑动窗口:
位置 63: [60, 61, 62, 63]
位置 127: [124, 125, 126, 127]
位置 191: [188, 189, 190, 191] ← 追踪这个!
位置 209: [206, 207, 208, 209]
追踪到 conv_states[slot_100]:
conv_states[slot_100] = [token_188, token_189, token_190, token_191]SSM State 变化
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
FLA Chunk 计算:
输入: 210 tokens
Chunk 0 [0-63]: h[0] = 状态@63
Chunk 1 [64-127]: h[1] = 状态@127
Chunk 2 [128-191]: h[2] = 状态@191
Chunk 3 [192-209]: h[3] = 状态@209 (不完整)
last_recurrent_state = 状态@209 (完整)
in-place 更新到 ssm_states[working_slot]
追踪 (is_aligned=True):
lens_to_track = 192 (对齐到64)
从 last_recurrent_state 获取
ssm_states[slot_100] = ssm_states[working_slot]
= 状态@209
但实际只能使用到位置192 (chunk边界)Radix Tree 结构
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
请求1完成后:
Root
└─ [0-209]
value: [kv_0, ..., kv_209] (210 tokens)
mamba_value: [slot_200]
slot_200:
conv: [token_188, token_189, token_190, token_191]
ssm: 状态@位置192
请求2完成后 (如果长度足够):
Root
├─ [0-209]
│ value: [kv_0, ..., kv_209]
│ mamba_value: [slot_200]
└─ [0-229] (假设长度足够,新增分支)
value: [kv_0, ..., kv_229]
mamba_value: [slot_301]
slot_301:
conv: [token_xxx, ..., token_yyy]
ssm: 状态@位置xxx常见问题
Q1: 为什么 Mamba Cache 只能缓存到192而不是210?
A: 因为 FLA (Flash Linear Attention) 的状态只在 chunk 边界 (64的倍数) 输出。
1
2
3
4
5
6
7
8
9
10
210 tokens 分块:
Chunk 0: [0-63] → 状态@63 ✓ 可缓存
Chunk 1: [64-127] → 状态@127 ✓ 可缓存
Chunk 2: [128-191] → 状态@191 ✓ 可缓存
Chunk 3: [192-209] → 状态@209 ✗ 不是完整chunk
虽然 last_recurrent_state 包含到209的完整状态,
但它包含了不对齐的部分 (192-209),无法作为其他请求的初始状态。
其他请求要从位置192开始,需要状态@191,而不是状态@209。Q2: 对齐损失有多大?
A: 取决于输入长度对 64 取模的结果。
1
2
3
4
5
6
7
8
9
10
11
12
输入长度 对齐长度 损失 损失率
64 64 0 0%
100 64 36 36%
128 128 0 0%
150 128 22 14.7%
192 192 0 0%
210 192 18 8.6%
256 256 0 0%
300 256 44 14.7%
512 512 0 0%
平均损失率约为 3-5% (假设随机长度)Q3: no_buffer vs extra_buffer 策略有什么区别?
A:
| 特性 | no_buffer (默认) | extra_buffer |
|---|---|---|
| cache_len | 整个序列长度 | 对齐长度 |
| 内存开销 | 无额外开销 | ping-pong buffer (2x) |
| 缓存粒度 | 保存整个序列,但使用时对齐 | 只保存对齐部分 |
| 适用场景 | 内存受限 | 需要更精确的缓存控制 |
示例:
1
2
3
4
5
6
7
8
9
10
11
输入: 210 tokens
no_buffer:
保存: 210 tokens 的 KV cache
保存: 210 tokens 的 mamba cache (实际只有192可用)
内存: 1x mamba cache
extra_buffer:
保存: 210 tokens 的 KV cache
保存: 192 tokens 的 mamba cache
内存: 2x mamba cache (ping-pong buffer)Q4: 为什么需要 Copy-on-Write?
A: 因为 Mamba Cache 是 in-place 更新的。
1
2
3
4
5
6
7
8
9
10
没有 CoW:
请求A和请求B共享 slot_200
请求A decode → 更新 ssm_states[slot_200] (in-place)
请求B decode → 从 ssm_states[slot_200] 读取 → 得到请求A的状态! ✗
有 CoW:
请求A和请求B共享 slot_200
请求A decode → fork to slot_300 → 更新 ssm_states[slot_300]
请求B decode → fork to slot_301 → 更新 ssm_states[slot_301]
两者互不干扰 ✓性能优化建议
1. 输入长度对齐
如果可以控制输入长度,尽量对齐到 64 的倍数:
1
2
3
4
# Client端padding
input_len = 210
aligned_len = ((input_len + 63) // 64) * 64 # 256
padded_input = input_ids + [pad_token_id] * (aligned_len - input_len)2. 使用 extra_buffer 策略
如果内存充足,使用 extra_buffer 策略:
1
2
3
4
python -m sglang.launch_server \
--model-path <model> \
--mamba-scheduler-strategy extra_buffer \
...3. 批量相似请求
尽量批量处理有相同前缀的请求,最大化 prefix cache 复用率。
总结
- FLA Chunk 机制 是理解 Mamba Cache 的关键
- 对齐损失 是不可避免的,但可以通过输入对齐最小化
- Copy-on-Write 保证多请求共享前缀时的正确性
- Tombstone 机制 允许 KV 和 Mamba cache 独立管理
- 实际缓存粒度 是 64 tokens,不是整个序列长度
希望这份可视化指南能帮助你更好地理解 Qwen3Next 的 Prefill Cache 机制!
评论
匿名评论隐私政策




