qwen3next-all

缓存的前缀长度必须是 64 的整数倍。如果分叉位置是 300,则缓存位置会回退到 256(最近的 64 对齐点)。
- 如果当前路径上从未发生分叉 → prefix = 最后一个 64 对齐点
- 如果中途发生过分叉 → prefix = 不超过分叉点的最后一个 64 对齐点
- CONTENT = "1" * 256 + "2" * 256  # prefix 0
- CONTENT = "1" * 256 + "3" * 256  # prefix 0
- CONTENT = "1" * 256 + "4" * 256  # prefix 256
- CONTENT = "1" * 256 + "2" * 256 + "5" * 256  # prefix 512
- CONTENT = "1" * 256 + "4" * 256 + "5" * 256  # prefix 512
- CONTENT = "1" * 256 + "3" * 256 + "5" * 256  # prefix 256
- CONTENT = "1" * 256 + "3" * 256  # prefix 256
- CONTENT = "1" * 256 + "3" * 256  # prefix 512
graph TD
subgraph S1["R1: (1+2) -> Prefix 0"]
r1((Root)) --- n1_1["1 (256)"] --- n1_2["2 (256)"]
style n1_2 fill:#4CAF50,color:#fff
end
subgraph S2["R2: (1+3) -> Prefix 0"]
r2((Root)) --- n2_1["1 (256)"]
n2_1 --- n2_2["2 (256)"]
n2_1 --- n2_3["3 (256)"]
style n2_1 stroke-dasharray: 5 5
style n2_2 fill:#4CAF50,color:#fff
style n2_1 fill:#4CAF50,color:#fff
end
subgraph S3["R3: (1+4) -> Prefix 256"]
r3((Root)) --- n3_1["1 (256)"]
n3_1 --- n3_2["2 (256)"]
n3_1 --- n3_3["3 (256)"]
n3_1 --- n3_4["4 (256)"]
style n3_1 fill:#4CAF50,color:#fff
style n3_2 fill:#4CAF50,color:#fff
style n3_4 fill:#4CAF50,color:#fff
end
subgraph S4["R4: (1+2+5) -> Prefix 512"]
r4((Root)) --- n4_1["1 (256)"]
n4_1 --- n4_2["2 (256)"]
n4_2 --- n4_5["5 (256)"]
n4_1 --- n4_3["3 (256)"]
n4_1 --- n4_4["4 (256)"]
style n4_1 fill:#4CAF50,color:#fff
style n4_2 fill:#4CAF50,color:#fff
style n4_4 fill:#4CAF50,color:#fff
style n4_5 fill:#4CAF50,color:#fff
end
subgraph S5["R5: (1+4+5) -> Prefix 512"]
r5((Root)) --- n5_1["1 (256)"]
n5_1 --- n5_2["2 (256)"]
n5_2 --- n5_2_1["5 (256)"]
n5_1 --- n5_3["3 (256)"]
n5_1 --- n5_4["4 (256)"]
n5_4 --- n5_5["5 (256)"]
style n5_1 fill:#4CAF50,color:#fff
style n5_4 fill:#4CAF50,color:#fff
style n5_2 fill:#4CAF50,color:#fff
style n5_2_1 fill:#4CAF50,color:#fff
style n5_5 fill:#4CAF50,color:#fff
end
subgraph S6["R6: (1+3+5) -> Prefix 256"]
r6((Root)) --- n6_1["1 (256)"]
n6_1 --- n6_2["2 (256)"]
n6_2 --- n6_2_1["5 (256)"]
n6_1 --- n6_3["3 (256)"]
n6_3 --- n6_5["5 (256)"]
n6_1 --- n6_4["4 (256)"]
n6_4 --- n6_4_1["5 (256)"]
style n6_2 fill:#4CAF50,color:#fff
style n6_2_1 fill:#4CAF50,color:#fff
style n6_1 fill:#4CAF50,color:#fff
style n6_4 fill:#4CAF50,color:#fff
style n6_4_1 fill:#4CAF50,color:#fff
style n6_5 fill:#4CAF50,color:#fff
end
subgraph S7["R7: (1+3) -> Prefix 256"]
r7((Root)) --- n7_1["1 (256)"]
n7_1 --- n7_2["2 (256)"]
n7_2 --- n7_2_1["5 (256)"]
n7_1 --- n7_3["3 (256)"]
n7_3 --- n7_5["5 (256)"]
n7_1 --- n7_4["4 (256)"]
n7_4 --- n7_4_1["5 (256)"]
style n7_1 fill:#4CAF50,color:#fff
style n7_2 fill:#4CAF50,color:#fff
style n7_2_1 fill:#4CAF50,color:#fff
style n7_3 fill:#4CAF50,color:#fff
style n7_5 fill:#4CAF50,color:#fff
style n7_4 fill:#4CAF50,color:#fff
style n7_4_1 fill:#4CAF50,color:#fff
end
subgraph S8["R8: (1+3) -> Prefix 512"]
r8((Root)) --- n8_1["1 (256)"]
n8_1 --- n8_2["..."]
n8_1["1 (256)"] --- n8_3["3 (256)"]
n8_3["3 (256)"] --- n8_3_1["..."]
n8_1 --- n8_4["..."]
style n8_1 fill:#4CAF50,color:#fff
style n8_2 fill:#4CAF50,color:#fff
style n8_3 fill:#4CAF50,color:#fff
style n8_4 fill:#4CAF50,color:#fff
style n8_3_1 fill:#4CAF50,color:#fff
end
S1 --> S2 --> S3 --> S4 --> S5 --> S6 --> S7 --> S8
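把上面的对齐规则写成代码就是一个简单的向下取整(`aligned_prefix` 为示意用的假设函数名,并非 SGLang 中的真实接口):

```python
def aligned_prefix(match_len: int, align: int = 64) -> int:
    """把可复用的前缀长度向下取整到 align 的整数倍。"""
    return (match_len // align) * align

assert aligned_prefix(300) == 256   # 分叉位置 300 -> 回退到 256
assert aligned_prefix(512) == 512   # 恰好对齐时不回退
```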
1. 总体架构
1.1 组件关系图
HybridReqToTokenPool (请求到Token的映射池)
│
├─> ReqToTokenPool (基础类)
│ └─> req_to_token: [size, max_context_len]
│ # 映射请求索引到token位置索引
│
├─> MambaPool (Mamba状态池)
│ │
│ ├─> State (正常状态)
│ │ ├─> conv: List[torch.Tensor] # 卷积状态
│ │ │ Shape: [num_layers, size + 1, conv_dim, kernel_size]
│ │ └─> temporal: torch.Tensor # SSM时序状态
│ │ Shape: [num_layers, size + 1, num_heads, head_dim, 2]
│ │
│ ├─> SpeculativeState (投机采样状态, 可选)
│ │ ├─> intermediate_ssm: torch.Tensor
│ │ └─> intermediate_conv_window: List[torch.Tensor]
│ │
│ └─> free_slots: torch.Tensor # 空闲slot索引
│
├─> req_index_to_mamba_index_mapping: torch.Tensor
│ # Shape: [size]
│ # 映射请求索引 → mamba slot索引
│
└─> req_index_to_mamba_ping_pong_track_buffer_mapping: torch.Tensor (extra_buffer模式,可选)
# Shape: [size, 2]
# 映射请求索引 → [ping_slot_idx, pong_slot_idx]

1.2 内存布局
GPU Memory Layout
┌─────────────────────────────────────────────────────────┐
│ Model Weights │
├─────────────────────────────────────────────────────────┤
│ KV Cache (MHATokenToKVPool / MLATokenToKVPool) │
│ - k_buffer: [num_layers, size, num_heads, head_dim] │
│ - v_buffer: [num_layers, size, num_heads, head_dim] │
├─────────────────────────────────────────────────────────┤
│ Mamba Cache (MambaPool) │
│ - conv_state: List[torch.Tensor] │
│ [num_layers, size+1, conv_dims] │
│ - temporal_state: torch.Tensor │
│ [num_layers, size+1, temporal_dims] │
└─────────────────────────────────────────────────────────┘

1.3 关键参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
| size | int | - | Mamba pool大小 (slot数量) |
| spec_state_size | int | - | 投机采样状态池大小 |
| enable_mamba_extra_buffer | bool | False | 是否启用ping-pong buffer |
| mamba_ping_pong_track_buffer_size | int | 2 | Ping-pong buffer大小 |
| speculative_num_draft_tokens | int | None | 投机采样draft token数量 |
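结合 1.2 的内存布局和上表参数,可以粗略估算 Mamba cache 的显存占用。下面的示意代码使用的是假设的层数和形状(并非某个真实模型的配置),只用来说明 conv_state 与 temporal_state 随 `size` 线性增长:

```python
import math
import torch

def mamba_cache_bytes(num_layers, size, conv_shapes, temporal_shape,
                      dtype=torch.float32) -> int:
    """按 1.2 的布局估算 conv_state + temporal_state 的总字节数(示意)。"""
    elem = torch.tensor([], dtype=dtype).element_size()
    conv_elems = sum(num_layers * (size + 1) * math.prod(s) for s in conv_shapes)
    temporal_elems = num_layers * (size + 1) * math.prod(temporal_shape)
    return (conv_elems + temporal_elems) * elem

total = mamba_cache_bytes(
    num_layers=32, size=1024,
    conv_shapes=[(4096, 3)],       # [conv_dim, kernel_size - 1],假设值
    temporal_shape=(64, 128, 2),   # [num_heads, head_dim, 2],假设值
)
print(f"约 {total / 1024**3:.2f} GB")
```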
2. MambaPool 类详解
文件位置:
python/sglang/srt/mem_cache/memory_pool.py:128-330
2.1 类定义和状态数据结构
class MambaPool:
"""Mamba状态内存池
功能:
1. 管理线性注意力层的SSM状态 (卷积状态 + 时序状态)
2. 提供slot分配/释放机制
3. 支持状态复制 (copy_from) 和写时复制 (fork_from)
4. 可选的投机采样状态缓存
"""
@dataclass(frozen=True, kw_only=True)
class State:
"""基础Mamba状态数据结构
Attributes:
conv: 卷积状态列表,每个元素是一个tensor
每个卷积状态可能有多维 (如不同层的卷积维度不同)
temporal: SSM递归状态
"""
conv: List[torch.Tensor]
temporal: torch.Tensor
def at_layer_idx(self, layer: int) -> "State":
"""提取指定层的状态
Args:
layer: 层索引
Returns:
只包含指定层的新State对象
示例:
>>> # 获取第5层的状态
>>> layer_5_state = mamba_cache.at_layer_idx(5)
>>> # layer_5_state.conv 是 [conv[5]]
>>> # layer_5_state.temporal 是 temporal[5]
"""
kwargs = {}
for k, v in vars(self).items():
if k == "conv" or k == "intermediate_conv_window":
# conv是列表,每个元素对应一层
kwargs[k] = [conv[layer] for conv in v]
else:
# temporal是tensor,第一个维度是层
kwargs[k] = v[layer]
return type(self)(**kwargs)
def mem_usage_bytes(self) -> int:
"""计算当前状态的内存占用 (字节)
Returns:
所有tensor的总字节数
"""
return sum(
get_tensor_size_bytes(getattr(self, f.name))
for f in dataclasses.fields(self)
)

2.2 SpeculativeState (投机采样状态)
@dataclass(frozen=True, kw_only=True)
class SpeculativeState(State):
"""投机采样时的扩展状态
在投机采样模式下,需要缓存中间状态用于验证:
- intermediate_ssm: 每个draft token的SSM状态
- intermediate_conv_window: 每个draft token的卷积窗口
Attributes:
intermediate_ssm: 中间SSM状态
Shape: [num_layers, size+1, num_draft_tokens, num_heads, head_dim, 2]
intermediate_conv_window: 中间卷积窗口
每个conv shape: [num_layers, size+1, num_draft_tokens, dim, kernel_size]
"""
intermediate_ssm: torch.Tensor
intermediate_conv_window: List[torch.Tensor]

2.3 初始化方法
def __init__(
self,
*,
size: int, # 池大小 (slot数量)
spec_state_size: int, # 投机采样状态池大小
cache_params: BaseLinearStateParams, # 缓存参数 (包含shape和dtype)
device: str, # 设备 ("cuda" 或 "cpu")
enable_memory_saver: bool = False, # 是否启用内存节省模式
speculative_num_draft_tokens: Optional[int] = None, # Draft token数量
):
"""初始化Mamba内存池
Args:
size: 最大支持的并发请求数
spec_state_size: 投机采样状态池大小
cache_params: 包含以下属性:
- shape.conv: List[tuple] - 卷积状态shape列表
- shape.temporal: tuple - 时序状态shape
- dtype.conv: torch.dtype - 卷积状态数据类型
- dtype.temporal: torch.dtype - 时序状态数据类型
- layers: List[int] - Mamba层ID列表
device: 设备字符串
enable_memory_saver: 是否使用TorchMemorySaverAdapter
speculative_num_draft_tokens: 投机采样draft token数量
"""
# ===== 1. 提取配置参数 =====
conv_state_shape = cache_params.shape.conv # 卷积状态shape列表
temporal_state_shape = cache_params.shape.temporal # 时序状态shape
conv_dtype = cache_params.dtype.conv # 卷积状态数据类型 (float32)
ssm_dtype = cache_params.dtype.temporal # SSM状态数据类型 (float32)
# 内存节省适配器 (用于CPU卸载)
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
num_mamba_layers = len(cache_params.layers) # Mamba层数量
self.size = size # 记录池大小
self.device = device # 记录设备
# ===== 2. NVLink disaggregation 支持 =====
# 用于多GPU分布式场景,使用自定义内存池
self.enable_custom_mem_pool, self.custom_mem_pool, _ = (
maybe_init_custom_mem_pool(device=self.device)
)
# ===== 3. 分配内存 =====
# 使用内存池上下文管理器 (如果启用custom_mem_pool)
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE), (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.enable_custom_mem_pool
else nullcontext()
):
# ===== 3.1 分配卷积状态 =====
# conv_state 是一个列表,因为不同层的卷积维度可能不同
# 每个 tensor shape: [num_layers, size + 1, conv_dim, conv_kernel_size]
# 注意: size + 1 是因为 slot 0 用作 padding (避免边界检查)
conv_state = [
torch.zeros(
size=(num_mamba_layers, size + 1) + conv_shape,
dtype=conv_dtype,
device=device,
)
for conv_shape in conv_state_shape
]
# ===== 3.2 分配时序状态 =====
# temporal_state shape: [num_layers, size + 1, num_heads, head_dim, 2]
# 最后的维度2存储:
# - index 0: h (主状态)
# - index 1: 辅助状态 (用于门控等)
temporal_state = torch.zeros(
size=(num_mamba_layers, size + 1) + temporal_state_shape,
dtype=ssm_dtype,
device=device,
)
# ===== 3.3 分配投机采样状态 (可选) =====
if speculative_num_draft_tokens is not None:
# Cache中间SSM状态: 每个draft token一个状态
# Shape: [num_layers, spec_state_size + 1, speculative_num_draft_tokens, HV, K, V]
# 说明:
# - num_layers: Mamba层数量
# - spec_state_size + 1: 投机状态池大小 (slot 0用于padding)
# - speculative_num_draft_tokens: draft token数量
# - HV: value头数量 (linear_num_value_heads)
# - K: key头维度 (linear_key_head_dim)
# - V: value头维度 (linear_value_head_dim)
intermediate_ssm_state_cache = torch.zeros(
size=(
num_mamba_layers,
spec_state_size + 1,
speculative_num_draft_tokens,
temporal_state_shape[0],
temporal_state_shape[1],
temporal_state_shape[2],
),
dtype=ssm_dtype,
device="cuda",
)
# Cache中间卷积窗口: 每个draft token一个窗口
# 存储最近K-1个输入 (K = conv_kernel_size)
# Shape: [num_layers, spec_state_size + 1, speculative_num_draft_tokens, dim, K-1]
intermediate_conv_window_cache = [
torch.zeros(
size=(
num_mamba_layers,
spec_state_size + 1,
speculative_num_draft_tokens,
conv_shape[0],
conv_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
for conv_shape in conv_state_shape
]
# 创建SpeculativeState对象
self.mamba_cache = self.SpeculativeState(
conv=conv_state,
temporal=temporal_state,
intermediate_ssm=intermediate_ssm_state_cache,
intermediate_conv_window=intermediate_conv_window_cache,
)
# 记录内存使用情况
logger.info(
f"Mamba Cache is allocated. "
f"max_mamba_cache_size: {size}, "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
)
else:
# 非投机采样模式: 创建基础State对象
self.mamba_cache = self.State(
conv=conv_state,
temporal=temporal_state
)
# 记录内存使用情况
logger.info(
f"Mamba Cache is allocated. "
f"max_mamba_cache_size: {size}, "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
)
# ===== 3.4 初始化空闲slot列表 =====
# free_slots: 可用的slot索引列表
# 初始时所有slot都可用: [0, 1, 2, ..., size-1]
# 注意: slot 0 实际被用作padding,分配时会跳过
self.free_slots = torch.arange(
self.size, dtype=torch.int64, device=self.device
)
# 记录总内存使用
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
self.num_mamba_layers = num_mamba_layers

2.4 内存分配和释放
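在进入 `alloc`/`free` 的具体实现之前,先用一段与 SGLang 无关的最小示意代码说明这里的 free-list 分配模式(`TinySlotPool` 为假设的演示类,省略了状态清零、形状等细节):

```python
import torch

class TinySlotPool:
    """仅演示 free_slots 的取用与归还逻辑,非真实实现。"""

    def __init__(self, size: int):
        self.free_slots = torch.arange(size, dtype=torch.int64)

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None                        # 空间不足
        select = self.free_slots[:need_size]   # 从队首取出 need_size 个 slot
        self.free_slots = self.free_slots[need_size:]
        return select

    def free(self, idx: torch.Tensor):
        if idx.numel() == 0:
            return
        self.free_slots = torch.cat((self.free_slots, idx))  # 归还到队尾

pool = TinySlotPool(4)
a = pool.alloc(2)                  # tensor([0, 1])
pool.free(a)
assert pool.available_size() == 4  # 释放后全部可用
```

下面回到 MambaPool 的实际实现。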
def available_size(self) -> int:
"""获取可用slot数量
Returns:
当前空闲的slot数量
"""
return len(self.free_slots)
def alloc(self, need_size: int) -> Optional[torch.Tensor]:
"""分配mamba cache slot
Args:
need_size: 需要分配的slot数量 (通常为1)
Returns:
分配的slot索引tensor,如果空间不足返回None
注意:
- 分配时会清零对应的slot
- Slot 0 作为padding,实际分配从slot 1开始
"""
# 检查是否有足够的空闲slot
if need_size > len(self.free_slots):
return None
# 从free_slots前面取need_size个slot
select_index = self.free_slots[:need_size]
# 更新free_slots,移除已分配的slot
self.free_slots = self.free_slots[need_size:]
# 清零分配的slot (避免使用旧数据)
for i in range(len(self.mamba_cache.conv)):
self.mamba_cache.conv[i][:, select_index] = 0
self.mamba_cache.temporal[:, select_index] = 0
return select_index
def free(self, free_index: torch.Tensor):
"""释放mamba cache slot
Args:
free_index: 要释放的slot索引tensor
注意:
- 释放时不会清零数据,只是将索引加回free_slots
- 下次分配时会清零
"""
if free_index.numel() == 0:
return
# 将释放的索引追加到free_slots
self.free_slots = torch.cat((self.free_slots, free_index))
def clear(self):
"""清空所有slot (重置池状态)
用途:
- 测试时重置
- 发生错误时恢复
"""
self.free_slots = torch.arange(
self.size, dtype=torch.int64, device=self.device
)

2.5 状态复制和写时复制
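`copy_from` 和 `fork_from` 的核心是按 slot 维(第 1 维)做切片赋值。下面是一段脱离 MambaPool 的最小示意,temporal 的形状沿用正文描述,具体数值均为假设:

```python
import torch

num_layers, size, num_heads, head_dim = 2, 8, 4, 16
# 形状: [num_layers, size + 1, num_heads, head_dim, 2]
temporal = torch.zeros(num_layers, size + 1, num_heads, head_dim, 2)

src = torch.tensor([5])
dst = torch.tensor([3])

temporal[:, src] = 1.0                # 假装 src slot 中已有状态
temporal[:, dst] = temporal[:, src]   # copy_from 的核心操作

assert torch.equal(temporal[:, dst], temporal[:, src])
```

fork_from 相当于先 alloc 一个新的 dst slot,再做同样的复制。下面是实际实现。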
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
"""复制mamba状态: src → dst
Args:
src_index: 源slot索引
dst_index: 目标slot索引
用途:
- Radix cache中的copy-on-write
- Prefix cache命中时的状态复制
示例:
>>> # 从slot 5复制到slot 10
>>> pool.copy_from(torch.tensor([5]), torch.tensor([10]))
"""
# 复制所有卷积状态
for i in range(len(self.mamba_cache.conv)):
self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][
:, src_index
]
# 复制时序状态
self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
:, src_index
]
return
def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]:
"""写时复制: 分配新slot并复制源状态
Args:
src_index: 源slot索引
Returns:
新分配的slot索引,如果分配失败返回None
用途:
- Radix cache的copy-on-write机制
- 多个请求共享相同prefix时fork状态
流程:
1. 分配新slot
2. 从src复制状态到新slot
3. 返回新slot索引
示例:
>>> # Fork slot 5到新slot
>>> new_slot = pool.fork_from(torch.tensor([5]))
>>> # new_slot 现在包含 slot 5 的状态副本
"""
dst_index = self.alloc(1)
if dst_index is None:
return None
self.copy_from(src_index, dst_index)
return dst_index

2.6 辅助方法
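2.6 中的 `mamba2_layer_cache` 最终落到 `State.at_layer_idx` 上,本质是对层维(第 0 维)做索引。下面是一段不依赖 dataclass 的最小示意,形状均为假设值:

```python
import torch

num_layers, size = 3, 4
# conv 是列表: 不同卷积状态的形状可以不同
conv = [
    torch.randn(num_layers, size + 1, 8, 4),
    torch.randn(num_layers, size + 1, 16, 4),
]
temporal = torch.randn(num_layers, size + 1, 2, 6, 2)

layer = 1
layer_conv = [c[layer] for c in conv]   # 每个元素形状 [size + 1, dim, kernel]
layer_temporal = temporal[layer]        # 形状 [size + 1, 2, 6, 2]
```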
def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
"""获取所有层的投机采样状态
Returns:
SpeculativeState对象
断言:
当前mamba_cache必须是SpeculativeState类型
"""
assert isinstance(self.mamba_cache, self.SpeculativeState)
return self.mamba_cache
def mamba2_layer_cache(self, layer_id: int) -> State:
"""获取指定层的mamba cache
Args:
layer_id: 层ID (全局层索引)
Returns:
只包含指定层的State对象
示例:
>>> layer_cache = pool.mamba2_layer_cache(5)
>>> # layer_cache.conv 是 [conv[5]]
>>> # layer_cache.temporal 是 temporal[5]
"""
return self.mamba_cache.at_layer_idx(layer_id)
def get_contiguous_buf_infos(self) -> Tuple[List[int], List[int], List[int]]:
"""获取所有状态tensor的内存信息
Returns:
data_ptrs: 每层每个tensor的数据指针列表
data_lens: 每层每个tensor的总字节数
item_lens: 每层每个tensor的单个元素字节数
用途:
- NVLink disaggregation
- 跨GPU内存共享
"""
state_tensors = []
# 将mamba_cache中的所有tensor展平成列表
for field in vars(self.mamba_cache):
value = getattr(self.mamba_cache, field)
if isinstance(value, list):
state_tensors.extend(value)
else:
state_tensors.append(value)
data_ptrs, data_lens, item_lens = [], [], []
# 遍历所有层,收集内存信息
for _, state_tensor in enumerate(state_tensors):
data_ptrs += [
state_tensor[i].data_ptr() for i in range(self.num_mamba_layers)
]
data_lens += [state_tensor[i].nbytes for i in range(self.num_mamba_layers)]
item_lens += [
state_tensor[i][0].nbytes for i in range(self.num_mamba_layers)
]
return data_ptrs, data_lens, item_lens

3. HybridReqToTokenPool 类详解
文件位置:
python/sglang/srt/mem_cache/memory_pool.py:332-510
3.1 类定义
class HybridReqToTokenPool(ReqToTokenPool):
"""混合请求到Token的映射池
继承自 ReqToTokenPool,添加了Mamba cache支持
功能:
1. 映射请求索引 → token位置 (基础功能)
2. 映射请求索引 → mamba slot索引 (新增)
3. 管理ping-pong buffer分配 (extra_buffer模式)
内存布局:
req_to_token: [size, max_context_len]
req_index_to_mamba_index_mapping: [size]
req_index_to_mamba_ping_pong_track_buffer_mapping: [size, 2] (可选)
"""
def __init__(
self,
*,
size: int, # 请求池大小
mamba_size: int, # Mamba pool大小
mamba_spec_state_size: int, # Mamba投机采样状态池大小
max_context_len: int, # 最大上下文长度
device: str, # 设备
enable_memory_saver: bool, # 是否启用内存节省
cache_params: BaseLinearStateParams, # Mamba cache参数
enable_mamba_extra_buffer: bool, # 是否启用ping-pong buffer
speculative_num_draft_tokens: int = None, # 投机采样draft token数
):
# ===== 1. 初始化基类 ReqToTokenPool =====
# 分配 req_to_token: [size, max_context_len]
super().__init__(
size=size,
max_context_len=max_context_len,
device=device,
enable_memory_saver=enable_memory_saver,
)
# ===== 2. 设置ping-pong buffer大小 =====
# 正常模式: 2个slot (ping和pong)
# 投机采样模式: 1个slot (因为已有intermediate状态)
self.mamba_ping_pong_track_buffer_size = (
2 if speculative_num_draft_tokens is None else 1
)
self.enable_mamba_extra_buffer = enable_mamba_extra_buffer
self.enable_memory_saver = enable_memory_saver
# ===== 3. 初始化MambaPool =====
self._init_mamba_pool(
size=mamba_size,
mamba_spec_state_size=mamba_spec_state_size,
cache_params=cache_params,
device=device,
enable_mamba_extra_buffer=enable_mamba_extra_buffer,
speculative_num_draft_tokens=speculative_num_draft_tokens,
)

3.2 MambaPool初始化
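`_init_mamba_pool` 里维护两张映射:全局层 ID → 连续的 Mamba 层索引(`mamba_map`),以及请求索引 → mamba slot 索引。下面的示意代码只保留这两张表,层 ID 和大小均为假设值:

```python
import torch

mamba_layer_ids = [5, 10, 15]  # 假设只有这三层是 Mamba 层
mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layer_ids)}
# {5: 0, 10: 1, 15: 2}

size = 16  # 请求池大小(假设)
req_index_to_mamba_index_mapping = torch.zeros(size, dtype=torch.int32)

# 假设请求 req_idx=3 分到了 mamba slot 7
req_index_to_mamba_index_mapping[3] = 7

# forward 时: 全局层 ID 先过 mamba_map,请求索引再查映射表拿 slot
layer_idx_in_pool = mamba_map[10]                 # -> 1
slot = int(req_index_to_mamba_index_mapping[3])   # -> 7
```

下面是实际的初始化代码。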
def _init_mamba_pool(
self,
size: int, # Mamba pool大小
mamba_spec_state_size: int, # 投机采样状态池大小
cache_params: BaseLinearStateParams, # 缓存参数
device: str, # 设备
enable_mamba_extra_buffer: bool, # 是否启用ping-pong buffer
speculative_num_draft_tokens: int = None, # Draft token数量
):
"""初始化Mamba内存池和映射表
功能:
1. 创建MambaPool实例
2. 创建层ID映射表
3. 创建请求到mamba slot的映射表
4. 创建请求到ping-pong buffer的映射表 (extra_buffer模式)
"""
# ===== 1. 创建MambaPool =====
self.mamba_pool = MambaPool(
size=size,
spec_state_size=mamba_spec_state_size,
cache_params=cache_params,
device=device,
enable_memory_saver=self.enable_memory_saver,
speculative_num_draft_tokens=speculative_num_draft_tokens,
)
# ===== 2. 创建层ID映射 =====
# mamba_map: 全局层ID → 连续层索引
# 示例: {5: 0, 10: 1, 15: 2}
# 说明: 只有Mamba层才会在map中
self.mamba_map = {
layer_id: i
for i, layer_id in enumerate(cache_params.layers)
}
self.device = device
# ===== 3. 创建请求索引 → mamba slot索引映射 =====
# Shape: [size]
# 值: 每个请求对应的mamba slot索引
# 初始值: 0 (表示未分配)
self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
size, dtype=torch.int32, device=self.device
)
# ===== 4. 创建请求索引 → ping-pong buffer映射 (可选) =====
# Shape: [size, mamba_ping_pong_track_buffer_size]
# - 正常模式: [size, 2]
# - 投机采样: [size, 1]
# 值: 每个请求对应的ping-pong slot索引
# 初始值: 0 (表示未分配)
if enable_mamba_extra_buffer:
self.req_index_to_mamba_ping_pong_track_buffer_mapping: torch.Tensor = (
torch.zeros(
(size, self.mamba_ping_pong_track_buffer_size),
dtype=torch.int32,
device=self.device,
)
)

3.3 分配方法
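`alloc` 的关键分支是:请求若已带有来自 radix cache 的 `mamba_pool_idx` 就直接复用,否则新分配,并把结果写入映射表。下面用一段极简的示意代码(`FakeReq` 等均为假设的演示对象)过一遍这个分支:

```python
import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeReq:
    mamba_pool_idx: Optional[int] = None  # radix cache 命中时已被填好

free_slots = list(range(8))                    # 简化版 mamba free-list
mapping = torch.zeros(16, dtype=torch.int32)   # req_idx -> mamba slot

def alloc_mamba_slot(req: FakeReq) -> int:
    if req.mamba_pool_idx is not None:         # 命中 radix cache: 复用
        return req.mamba_pool_idx
    slot = free_slots.pop(0)                   # 否则从 free-list 新分配
    req.mamba_pool_idx = slot
    return slot

reqs = [FakeReq(), FakeReq(mamba_pool_idx=5)]
select_index = [0, 1]                          # 假设基类分到的请求 slot
mamba_index = [alloc_mamba_slot(r) for r in reqs]
mapping[torch.tensor(select_index)] = torch.tensor(mamba_index, dtype=torch.int32)
# mapping[0] == 0 (新分配), mapping[1] == 5 (复用 radix cache 中的 slot)
```

真实实现还要处理 ping-pong buffer 和各类断言,见下文。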
def alloc(
self,
need_size: int, # 需要分配的请求数量
reqs: Optional[List["Req"]] # 请求对象列表
) -> Optional[List[int]]:
"""分配请求slot和对应的mamba cache
Args:
need_size: 需要分配的请求数
reqs: 请求对象列表,每个请求包含:
- req.mamba_pool_idx: 已分配的mamba slot (可能来自radix cache)
- req.mamba_ping_pong_track_buffer: ping-pong buffer (extra_buffer模式)
Returns:
分配的请求索引列表,分配失败返回None
分配流程:
1. 分配请求slot (基类)
2. 为每个请求分配mamba slot
3. 如果启用extra_buffer,分配ping-pong buffer
4. 更新映射表
注意:
- 如果请求已有mamba_pool_idx (来自radix cache),复用之
- Chunked prefill时不重新分配mamba cache
"""
assert reqs is not None
# ===== 1. 分配请求slot (基类) =====
select_index = super().alloc(need_size)
if select_index is None:
return None
# ===== 2. 为每个请求分配mamba slot =====
mamba_index = []
mamba_ping_pong_track_buffer_list = []
for req in reqs:
mid = None
# 检查是否已有mamba slot (来自radix cache)
if req.mamba_pool_idx is not None:
# Radix cache命中,复用已有slot
mid = req.mamba_pool_idx
else:
# 分配新的mamba slot
mid = self.mamba_pool.alloc(1)
assert (
mid is not None
), (
f"Not enough space for mamba cache, "
f"try to increase --mamba-full-memory-ratio or --max-mamba-cache-size. "
f"{mid=}, {self.mamba_pool.size=}, "
f"{self.mamba_pool.available_size()=}, {len(reqs)=}"
)
mid = mid[0] # 从tensor转为int
req.mamba_pool_idx = mid # 记录到请求对象
mamba_index.append(mid)
# ===== 3. 分配ping-pong buffer (extra_buffer模式) =====
if self.enable_mamba_extra_buffer:
if req.mamba_ping_pong_track_buffer is None:
# 首次分配ping-pong buffer
req.mamba_ping_pong_track_buffer = self.mamba_pool.alloc(
self.mamba_ping_pong_track_buffer_size # 2或1
)
assert (
req.mamba_ping_pong_track_buffer is not None
), "Not enough space for mamba ping pong idx, try to increase --mamba-full-memory-ratio."
req.mamba_next_track_idx = 0 # 初始化ping-pong索引
# 记录ping-pong buffer索引
mamba_ping_pong_track_buffer_list.append(
req.mamba_ping_pong_track_buffer.tolist()
)
# ===== 4. 验证分配成功 =====
assert len(select_index) == len(
mamba_index
), f"Not enough space for mamba cache, try to increase --mamba-full-memory-ratio or --max-mamba-cache-size."
if self.enable_mamba_extra_buffer:
assert len(select_index) == len(
mamba_ping_pong_track_buffer_list
), f"Not enough space for mamba ping pong idx, try to increase --mamba-full-memory-ratio."
# ===== 5. 更新映射表 =====
# req_index_to_mamba_index_mapping[request_idx] = mamba_slot_idx
self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
mamba_index, dtype=torch.int32, device=self.device
)
# req_index_to_mamba_ping_pong_track_buffer_mapping[request_idx] = [slot1, slot2]
if self.enable_mamba_extra_buffer:
self.req_index_to_mamba_ping_pong_track_buffer_mapping[select_index] = (
torch.tensor(
mamba_ping_pong_track_buffer_list,
dtype=torch.int32,
device=self.device,
)
)
return select_index

3.4 释放方法
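释放时对 ping-pong buffer 的处理是"保留一个、释放另一个"。这段索引逻辑单独拿出来看会更直观(buffer 中的 slot 编号为假设值):

```python
import torch

mamba_ping_pong_track_buffer_size = 2
buffer = torch.tensor([11, 12])   # [slot_ping, slot_pong],假设值

keep = 0                                                      # 希望保留的下标 (0 或 1)
idx_to_free = list(range(mamba_ping_pong_track_buffer_size))  # [0, 1]
idx_to_free.remove(keep)                                      # [1]

to_free = buffer[idx_to_free]   # tensor([12]): 只释放另一个 slot
to_keep = buffer[keep]          # tensor(11)
```

下面是 free 的完整实现。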
def free(
self,
free_index: Union[int, List[int]], # 要释放的请求索引
free_mamba_cache: bool = True, # 是否释放mamba cache
mamba_ping_pong_track_buffer_to_keep: Optional[int] = None, # 保留哪个ping-pong slot
):
"""释放请求slot和对应的mamba cache
Args:
free_index: 请求索引 (单个或列表)
free_mamba_cache: 是否释放mamba cache (chunked prefill时为False)
mamba_ping_pong_track_buffer_to_keep: 保留的ping-pong slot索引 (0或1)
- 用于prefix cache保存时保留一个ping-pong slot
- None表示释放所有ping-pong slot
释放策略:
- 基础请求slot: 总是释放
- Mamba工作slot: 可选释放 (chunked prefill时保留)
- Ping-pong buffer: 可选保留一个slot
示例:
>>> # 释放所有cache
>>> pool.free(req_idx, free_mamba_cache=True)
>>>
>>> # Chunked prefill: 保留mamba cache
>>> pool.free(req_idx, free_mamba_cache=False)
>>>
>>> # Prefix cache: 保留ping-pong slot 0
>>> pool.free(req_idx, free_mamba_cache=True, mamba_ping_pong_track_buffer_to_keep=0)
"""
if isinstance(free_index, (int,)):
free_index = [free_index]
# ===== 1. 释放基础请求slot =====
super().free(free_index)
# ===== 2. 释放mamba cache (可选) =====
if free_mamba_cache:
# 获取要释放的mamba slot索引
mamba_index = self.req_index_to_mamba_index_mapping[free_index]
self.mamba_pool.free(mamba_index)
# ===== 3. 释放ping-pong buffer (extra_buffer模式) =====
if self.enable_mamba_extra_buffer:
# 获取ping-pong buffer索引
mamba_ping_pong_track_buffer_to_free = (
self.req_index_to_mamba_ping_pong_track_buffer_mapping[
free_index
].squeeze(0)
)
# 如果指定要保留某个slot
if mamba_ping_pong_track_buffer_to_keep is not None:
assert mamba_ping_pong_track_buffer_to_keep in [0, 1], (
f"mamba_ping_pong_track_buffer_to_keep must be 0 or 1, "
f"{mamba_ping_pong_track_buffer_to_keep=}"
)
# 创建所有slot的索引
idx_to_free = list(range(self.mamba_ping_pong_track_buffer_size))
# 移除要保留的slot
idx_to_free.remove(mamba_ping_pong_track_buffer_to_keep)
# 获取要释放的slot (另一个)
mamba_ping_pong_track_buffer_to_free = (
mamba_ping_pong_track_buffer_to_free[idx_to_free]
)
# 释放ping-pong buffer (可能只释放一个)
self.mamba_pool.free(mamba_ping_pong_track_buffer_to_free)
def clear(self):
"""清空所有slot (重置池状态)"""
logger.info("Reset HybridReqToTokenPool")
# 清空基类
super().clear()
# 清空mamba pool
self.mamba_pool.clear()
# 清零映射表
self.req_index_to_mamba_index_mapping.zero_()
if self.enable_mamba_extra_buffer:
self.req_index_to_mamba_ping_pong_track_buffer_mapping.zero_()

3.5 辅助方法
def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
"""获取请求索引对应的mamba slot索引
Args:
req_indices: 请求索引tensor
Returns:
mamba slot索引tensor
示例:
>>> req_indices = torch.tensor([5, 10, 15])
>>> mamba_indices = pool.get_mamba_indices(req_indices)
>>> # mamba_indices: [3, 7, 12]
"""
return self.req_index_to_mamba_index_mapping[req_indices]
def mamba2_layer_cache(self, layer_id: int):
"""获取指定层的mamba cache
Args:
layer_id: 层ID (全局层索引)
Returns:
指定层的State对象
断言:
layer_id必须在mamba_map中 (即必须是Mamba层)
"""
assert layer_id in self.mamba_map
# 将全局层ID转换为连续索引
return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])
def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
"""获取所有层的投机采样mamba状态
Returns:
SpeculativeState对象
"""
return self.mamba_pool.get_speculative_mamba2_params_all_layers()
def get_mamba_ping_pong_other_idx(self, mamba_next_track_idx: int) -> int:
"""获取ping-pong buffer的另一个索引
Args:
mamba_next_track_idx: 当前ping-pong索引 (0或1)
Returns:
另一个索引 (1或0)
示例:
>>> pool.get_mamba_ping_pong_other_idx(0) # 返回 1
>>> pool.get_mamba_ping_pong_other_idx(1) # 返回 0
注意:
- 投机采样模式下 (buffer_size=1),返回自身
"""
if self.mamba_ping_pong_track_buffer_size == 2:
return 1 - mamba_next_track_idx
else:
return mamba_next_track_idx

4. 内存分配流程
4.1 整体初始化流程
1. 计算内存大小
└─> model_runner_kv_cache_mixin.py:179-192
计算公式: mamba_memory = total * ratio / (1 + ratio)
2. 创建HybridReqToTokenPool
└─> new HybridReqToTokenPool(
size=max_total_tokens,
mamba_size=max_mamba_cache_size,
mamba_spec_state_size=speculative_num_draft_tokens,
...
)
3. 初始化MambaPool
└─> _init_mamba_pool()
├─> 创建MambaPool实例
├─> 分配conv_state
├─> 分配temporal_state
└─> 分配intermediate状态 (投机采样)
4. 初始化映射表
├─> req_index_to_mamba_index_mapping
└─> req_index_to_mamba_ping_pong_track_buffer_mapping (extra_buffer)

4.2 请求分配流程
请求到达
│
├─> 检查是否有radix cache
│ └─> 有: req.mamba_pool_idx 已设置
│ └─> 无: req.mamba_pool_idx = None
│
├─> 调用 pool.alloc(need_size, reqs)
│
├─> 分配请求slot (基类)
│ └─> select_index = super().alloc(need_size)
│
├─> 为每个请求分配mamba slot
│ ├─> 如果 req.mamba_pool_idx is None
│ │ └─> mid = mamba_pool.alloc(1)
│ │ └─> 从free_slots取出一个slot
│ │ └─> 清零slot
│ │ └─> req.mamba_pool_idx = mid
│ └─> 否则复用已有slot
│
├─> 如果启用extra_buffer
│ └─> 如果 req.mamba_ping_pong_track_buffer is None
│ └─> 分配2个slot: [slot_ping, slot_pong]
│ └─> req.mamba_next_track_idx = 0
│
└─> 更新映射表
├─> req_index_to_mamba_index_mapping[req_idx] = mamba_slot
└─> req_index_to_mamba_ping_pong_...[req_idx] = [slot1, slot2]

4.3 请求释放流程
请求完成
│
├─> 调用 pool.free(req_idx, free_mamba_cache, keep_ping_pong)
│
├─> 释放请求slot (基类)
│ └─> super().free(req_idx)
│
├─> 如果 free_mamba_cache = True
│ ├─> 获取mamba slot索引
│ └─> mamba_pool.free(mamba_slot)
│ └─> 将slot加回free_slots
│
└─> 如果启用extra_buffer
├─> 如果 keep_ping_pong is None
│ └─> 释放所有ping-pong slot
└─> 否则
└─> 只释放另一个slot,保留keep_ping_pong指定的slot