1-piecewise实现原理

1-piecewise实现原理
gogongxtSGLang Piecewise CUDA Graph 实现机制技术文档
文档信息
| 项目 | 说明 |
|---|---|
| 目标 | 系统性分析 SGLang 当前的 piecewise 实现机制,形成面向工程实践的技术文档 |
| 版本 | SGLang 0.5.7 |
| 代码路径 | python/sglang/srt/compilation/,
python/sglang/srt/model_executor/ |
1. 设计背景与动机
1.1 解决的问题
CUDA Graph 在动态 shape 场景下的限制
传统 CUDA Graph capture 要求:
- 输入 tensor 的 shape 必须在编译时确定
- Memory layout 必须固定
- 不支持动态控制流
在 LLM 推理场景中:
- 输入序列长度(token 数)变化范围大
- Batch size 动态变化
- 不同请求的序列长度差异显著
1.2 使用场景
变长序列推理的核心需求
1
2
3
请求1: [token1, token2, ..., token16] → 需要 size=16 的 graph
请求2: [token1, token2, ..., token32] → 需要 size=32 的 graph
请求3: [token1, token2, ..., token128] → 需要 size=128 的 graphPiecewise 解决方案:
- 预先捕获多个固定 size 的 CUDA Graph
- 运行时根据输入 token 数选择合适的 graph
- 使用全局 memory pool 共享内存,提高利用率
1.3 核心差异
| 特性 | 传统单次 CUDA Graph | Piecewise CUDA Graph |
|---|---|---|
| 捕获次数 | 1 次 | 多次(每个 size 一次) |
| Shape 支持 | 固定 | 多个固定 size |
| Graph 选择 | 无需选择 | Binary search 定位 |
| 内存管理 | 独立分配 | 全局 pool 共享 |
| Torch.compile 集成 | 独立使用 | 联合优化 |
2. 实现原理 (How it works)
2.1 整体工作流程
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
┌─────────────────────────────────────────────────────────────────┐
│ 初始化阶段 (Initialization) │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ 1. 创建 CompilationConfig (capture_sizes, compiler) │ │
│ │ 2. 初始化全局 graph memory pool │ │
│ │ 3. 设置 torch.compile 配置 │ │
│ └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────────┐
│ Torch.compile 阶段 │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ 1. install_torch_compiled(fullgraph=True) │ │
│ │ 2. 对每个 capture_size 执行 warmup │ │
│ │ 3. 图切分 (split_graph) 按 split_ops │ │
│ │ 4. 编译每个 subgraph (PiecewiseCompileInterpreter) │ │
│ └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────────┐
│ CUDA Graph 捕获阶段 │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ for num_tokens in reversed(capture_sizes): │ │
│ │ 1. capture_one_batch_size(num_tokens) │ │
│ │ 2. 执行 forward pass (2次: warmup + capture) │ │
│ │ 3. torch.cuda.graph(cudagraph, pool=graph_pool) │ │
│ └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────────┐
│ 执行阶段 (Runtime) │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ 1. can_run(): 检查是否在捕获范围内 │ │
│ │ 2. replay_prepare(): 准备输入数据 │ │
│ │ 3. binary search 定位合适的 graph │ │
│ │ 4. cudagraph.replay(): 重放 graph │ │
│ └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘2.2 Piece 组织方式
按 token 数量划分
每个 capture size 对应一个 ConcreteSizeEntry:
1
2
3
4
5
6
7
8
9
10
11
12
# From cuda_piecewise_backend.py:23-38
@dataclasses.dataclass
class ConcreteSizeEntry:
runtime_shape: int # Token 数量
need_to_compile: bool # 是否需要编译
use_cudagraph: bool # 是否使用 CUDA Graph
compiled: bool = False
runnable: Callable = None # 编译后的可执行函数
num_finished_warmup: int = 0 # Warmup 计数
cudagraph: Optional[torch.cuda.CUDAGraph] = None
output: Optional[Any] = None # 输出 (weak ref)2.3 执行切换机制
Binary Search 快速定位
1
2
3
4
5
6
7
# From piecewise_cuda_graph_runner.py:545
def replay_prepare(self, forward_batch: ForwardBatch, **kwargs):
num_tokens = len(forward_batch.input_ids)
# 使用 bisect_left 找到第一个 >= num_tokens 的 capture size
index = bisect.bisect_left(self.capture_num_tokens, num_tokens)
static_num_tokens = self.capture_num_tokens[index]
# ...示例:
1
2
3
4
5
capture_sizes = [16, 32, 64, 128, 256]
实际输入 num_tokens=45 → index=2 → static_num_tokens=64
实际输入 num_tokens=128 → index=3 → static_num_tokens=128 (精确匹配)
实际输入 num_tokens=300 → 不在范围内,can_run() 返回 False3. 实现方式与关键数据结构
3.1 核心类
3.1.1 PiecewiseCudaGraphRunner
主执行器
(piecewise_cuda_graph_runner.py:127)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class PiecewiseCudaGraphRunner:
"""使用 CUDA Graph 和 torch.compile 运行模型前向传播"""
def __init__(self, model_runner: ModelRunner):
# 初始化配置
self.compile_config = CompilationConfig(
capture_sizes,
compiler, # "eager" or "inductor"
enable_debug_mode
)
# 输入缓冲区 (预分配最大 size)
self.input_ids = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
self.positions = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
# ...
# 全局内存池
set_global_graph_memory_pool(self.device_module.graph_pool_handle())
# 编译和捕获
with enable_piecewise_cuda_graph():
install_torch_compiled(patched_model, fullgraph=True)
self.capture() # 捕获所有 sizes主要方法:
| 方法 | 行号 | 功能 |
|---|---|---|
warmup_torch_compile() |
282 | Torch.compile 预热 |
capture() |
382 | 捕获所有 sizes 的 CUDA Graph |
capture_one_batch_size() |
416 | 捕获单个 batch size |
can_run() |
369 | 检查是否可使用 CUDA Graph |
replay_prepare() |
539 | 准备 replay 的输入数据 |
replay() |
670 | 执行 CUDA Graph replay |
3.1.2 CompilationConfig
编译配置 (compilation_config.py:7)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class CompilationConfig:
def __init__(
self,
capture_sizes: List[int], # 要捕获的 token sizes
compiler: str = "eager", # 编译器类型
enable_debug_mode: bool = False,
):
self.capture_sizes = capture_sizes
self.compiler = compiler
self.enable_debug_mode = enable_debug_mode
# OP 隔离配置
self.split_ops = [
"sglang.unified_attention_with_output",
"sglang.gdn_with_output",
"sglang.inplace_all_reduce",
]3.1.3 CUDAPiecewiseBackend
CUDA 后端实现
(cuda_piecewise_backend.py:40)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class CUDAPiecewiseBackend:
def __call__(self, *args) -> Any:
# 1. 首次运行:返回通用编译的 graph
if not self.first_run_finished:
return self.compiled_graph_for_general_shape(*args)
# 2. 获取运行时 shape
runtime_shape = args[self.sym_shape_indices[0]]
# 3. 查找对应的 entry
entry = self.concrete_size_entries[runtime_shape]
# 4. 如果需要编译,先编译
if entry.need_to_compile and not entry.compiled:
entry.runnable = self.sglang_backend.compiler_manager.compile(...)
entry.compiled = True
# 5. 如果需要 CUDA Graph,进行捕获
if entry.cudagraph is None:
if entry.num_finished_warmup < 1:
entry.num_finished_warmup += 1
return entry.runnable(*args) # Warmup
# 捕获 CUDA Graph
cudagraph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
output = entry.runnable(*args)
entry.cudagraph = cudagraph
# 6. Replay CUDA Graph
entry.cudagraph.replay()
return entry.output3.1.4 NPUPiecewiseBackend
NPU 后端实现
(npu_piecewise_backend.py)
继承自 CUDAPiecewiseBackend,为 NPU 硬件提供定制化实现。
3.2 关键数据结构
3.2.1 ConcreteSizeEntry
单个 token size 的管理对象
(cuda_piecewise_backend.py:23)
1
2
3
4
5
6
7
8
9
10
11
12
@dataclasses.dataclass
class ConcreteSizeEntry:
runtime_shape: int # 运行时 token 数量
need_to_compile: bool # 该 size 是否需要特殊编译
use_cudagraph: bool # 该 size 是否使用 CUDA Graph
compiled: bool = False # 是否已完成编译
runnable: Callable = None # 可执行函数
num_finished_warmup: int = 0 # Warmup 次数计数
cudagraph: Optional[torch.cuda.CUDAGraph] = None
output: Optional[Any] = None # 输出 (weak ref)
input_addresses: Optional[list[int]] = None # 调试用3.2.2 ForwardContext
跨 piece 传递的执行上下文
(piecewise_context_manager.py:52)
1
2
3
4
5
6
@dataclass
class ForwardContext:
forward_batch: ForwardBatch = None # 批次信息
attention_layer: List[Any] = None # Attention 层
quant_config: Any = None # 量化配置
moe_layers: List[Any] = None # MoE 层通过 set_forward_context() context manager
在执行期间设置:
1
2
3
4
5
6
7
with set_forward_context(
forward_batch,
self.attention_layers,
self.quant_config,
self.moe_layers,
):
output = self.model_runner.model.forward(...)3.2.3 全局内存池
跨 runners 共享的内存管理
1
2
3
4
5
6
7
8
9
# From piecewise_cuda_graph_runner.py:106-115
global_graph_memory_pool = None
def get_global_graph_memory_pool():
return global_graph_memory_pool
def set_global_graph_memory_pool(val):
global global_graph_memory_pool
global_graph_memory_pool = val优势:
- 多个 runner 共享,减少内存占用
- 大 graph 优先捕获,小 graph 复用内存
- 使用 weak reference 释放不用的输出
3.3 Piece 生命周期
1
2
3
创建 → Warmup → Capture (torch.compile + CUDA graph) → Replay → 清理
↓ ↓ ↓ ↓ ↓
初始化 首次运行 编译+捕获 运行时 GC详细流程:
- 创建
(
PiecewiseCudaGraphRunner.__init__)- 分配输入缓冲区
- 初始化全局内存池
- 设置配置
- Warmup (
warmup_torch_compile)- 执行一次 forward pass
- 触发 torch.compile
- 建立计算图
- Capture (
capture_one_batch_size)- 第一次调用:Warmup
- 第二次调用:捕获 CUDA Graph
- 存储到
ConcreteSizeEntry
- Replay (
replay)- 查找合适的 graph
- 复制输入数据
- 调用
cudagraph.replay()
- 清理
- Weak reference 自动释放
- GC 回收临时对象
4. OP 隔离机制详解
4.1 隔离策略
基于 operation 的 graph 切分
通过 split_ops 配置指定需要在 graph
边界处切分的操作:
1
2
3
4
5
6
# From compilation_config.py:18-22
self.split_ops = [
"sglang.unified_attention_with_output", # 统一注意力计算
"sglang.gdn_with_output", # GDN (Grouped Distributed Norm)
"sglang.inplace_all_reduce", # 分布式通信
]4.2 隔离实现方式
4.2.1 Graph 切分
使用
torch.fx.profiler.split_module()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# From backend.py:213-256
def split_graph(
graph: fx.GraphModule, ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
# 为每个 node 分配 subgraph ID
subgraph_id = 0
node_to_subgraph_id = {}
split_op_graphs = []
for node in graph.graph.nodes:
if node.op == "call_function" and str(node.target) in ops:
# 在 split op 处增加 subgraph_id,创建边界
subgraph_id += 1
node_to_subgraph_id[node] = subgraph_id
split_op_graphs.append(subgraph_id)
subgraph_id += 1
else:
node_to_subgraph_id[node] = subgraph_id
# 使用 keep_original_order=True 保持语义正确
split_gm = torch.fx.passes.split_module.split_module(
graph, None,
lambda node: node_to_subgraph_id[node],
keep_original_order=True # 重要!保持顺序
)
return split_gm, outputs4.2.2 编译阶段
独立编译每个 subgraph
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# From backend.py:396-472
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
# 1. 切分图
self.split_gm, self.piecewise_graphs = split_graph(
graph,
self.compile_config.split_ops,
)
# 2. 选择需要编译的 submodules
submod_names_to_compile = [
item.submod_name
for item in self.piecewise_graphs
if not item.is_splitting_graph # 不编译 split op 本身
]
# 3. 使用 PiecewiseCompileInterpreter 编译
PiecewiseCompileInterpreter(
self.split_gm,
submod_names_to_compile,
self.inductor_config,
self.graph_pool,
self.compile_config,
self,
).run(*example_inputs)4.3 默认隔离的 OP
| OP 名称 | 功能 | 隔离原因 |
|---|---|---|
sglang.unified_attention_with_output |
统一注意力计算 | 包含复杂控制流,shape 相关逻辑 |
sglang.gdn_with_output |
Grouped Distributed Norm | 分布式同步点,需要独立执行 |
sglang.inplace_all_reduce |
分布式 all-reduce | 通信操作,需要同步 |
sglang.moe_forward_piecewise_cuda_graph_impl |
MoE 前向传播 (DeepEP/Mooncake) | 动态路由,控制流复杂 |
4.4 执行层面影响
4.4.1 执行边界
Split ops 处形成执行边界
1
2
3
[Subgraph 0] → unified_attention (split op) → [Subgraph 1] → gdn (split op) → [Subgraph 2]
↑ ↑
执行边界 执行边界在每个边界处:
- 退出 torch.compile 模式
- 执行 custom op
- 重新进入 torch.compile 模式
4.4.2 同步点控制
CustomOp 通过 enter/leave 控制
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# From custom_op.py:35-69
class CustomOp(nn.Module):
def enter_torch_compile(self, num_tokens: int):
# 进入 torch.compile 模式
if self.is_torch_compile:
return # 已经在编译模式中
self._original_forward_method = self._forward_method
self._forward_method = self.forward_native # 使用原生实现
self.is_torch_compile = True
def leave_torch_compile(self):
# 离开 torch.compile 模式
if not self.is_torch_compile:
return
self._forward_method = self._original_forward_method
self.is_torch_compile = False4.4.3 内存隔离
每个 piece 使用独立内存分配
虽然使用全局 graph pool,但每个 piece 的输入输出内存是独立的:
1
2
3
4
# 在捕获时,使用 weak ref 减少 memory hold
if self.is_last_graph:
output = weak_ref_tensors(output) # 仅对最后一个 graph
entry.output = weak_ref_tensors(output) # 所有 graph 都使用 weak ref4.5 隔离机制示例
以 Attention 为例:
1
2
3
4
5
6
7
8
9
10
原始计算图:
input → layernorm → attention → mlp → output
切分后 (split attention):
[Subgraph 0: layernorm] → attention (split) → [Subgraph 1: mlp]
执行流程:
1. Subgraph 0 使用 CUDA Graph replay
2. Attention 退出 graph,正常执行
3. Subgraph 1 使用 CUDA Graph replay5. 代码详细解析
5.1 关键文件清单
1
2
3
4
5
6
7
8
9
10
11
12
13
python/sglang/srt/
├── compilation/
│ ├── compilation_config.py - 配置管理 (7-38 行)
│ ├── backend.py - 后端工厂和 graph 切分 (29-473 行)
│ ├── cuda_piecewise_backend.py - CUDA 实现 (40-210 行)
│ ├── npu_piecewise_backend.py - NPU 实现
│ ├── piecewise_context_manager.py - 上下文管理 (1-99 行)
│ ├── compiler_interface.py - 编译器接口
│ └── weak_ref_tensor.py - Weak reference 工具
├── model_executor/
│ ├── piecewise_cuda_graph_runner.py - 主执行器 (127-741 行)
│ └── model_runner.py - 集成入口
└── custom_op.py - 自定义操作基类 (26-109 行)5.2 核心代码路径
5.2.1 初始化流程
1
2
3
4
5
6
7
8
9
10
11
ModelRunner.init_piecewise_cuda_graphs()
↓
PiecewiseCudaGraphRunner.__init__() (137 行)
↓
install_torch_compiled() (246-252 行)
↓
warmup_torch_compile() (265, 282 行)
↓
capture() (274 行)
↓
capture_one_batch_size() (414, 416 行)关键代码片段:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# From piecewise_cuda_graph_runner.py:137-278
def __init__(self, model_runner: ModelRunner):
# 1. 配置初始化
self.compile_config = CompilationConfig(
self.model_runner.server_args.piecewise_cuda_graph_tokens,
self.model_runner.server_args.piecewise_cuda_graph_compiler,
)
# 2. 内存池初始化
if get_global_graph_memory_pool() is None:
set_global_graph_memory_pool(self.device_module.graph_pool_handle())
set_graph_pool_id(get_global_graph_memory_pool())
# 3. Torch.compile 安装
with enable_piecewise_cuda_graph():
install_torch_compiled(
patched_model,
fullgraph=True,
compile_config=self.compile_config,
graph_pool=get_global_graph_memory_pool(),
)
# 4. Warmup 和 Capture
with set_compiled(True):
for num_tokens in reversed(self.capture_num_tokens):
self.warmup_torch_compile(num_tokens)
self.capture() # 5. 捕获 CUDA Graph5.2.2 执行流程
1
2
3
4
5
6
7
8
9
10
11
12
13
ModelRunner.forward()
↓
can_run() (369 行)
↓
replay() (670 行)
↓
replay_prepare() (539, 677 行)
↓
model.forward() (686 行)
↓
CUDAPiecewiseBackend.__call__() (cuda_piecewise_backend.py:107)
↓
cudagraph.replay() (cuda_piecewise_backend.py:208)关键代码片段:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# From piecewise_cuda_graph_runner.py:369-380
def can_run(self, forward_batch: ForwardBatch):
num_tokens = len(forward_batch.input_ids)
# 检查 logprob 是否在中间状态
if forward_batch.return_logprob:
for start_len, seq_len in zip(...):
if start_len is not None and start_len < seq_len:
return False # 不支持 logprob 中间状态
# 检查是否在捕获范围内
if num_tokens <= self.max_num_tokens:
return True
return False1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# From piecewise_cuda_graph_runner.py:539-668
def replay_prepare(self, forward_batch: ForwardBatch, **kwargs):
num_tokens = len(forward_batch.input_ids)
# Binary search 查找合适的 size
index = bisect.bisect_left(self.capture_num_tokens, num_tokens)
static_num_tokens = self.capture_num_tokens[index]
# 复制输入数据
self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
self.positions[:num_tokens].copy_(forward_batch.positions)
# ...
# 创建静态 ForwardBatch
static_forward_batch = ForwardBatch(
input_ids=self.input_ids[:static_num_tokens],
positions=self.positions[:static_num_tokens],
# ...
)
return static_forward_batch5.2.3 编译流程
1
2
3
4
5
6
7
8
9
SGLangBackend.__call__() (backend.py:396)
↓
split_graph() (backend.py:424, 213)
↓
PiecewiseCompileInterpreter.run() (backend.py:443, 287)
↓
CompilerManager.compile() (backend.py:312, 128)
↓
CUDAPiecewiseBackend.__call__() (cuda_piecewise_backend.py:107)关键代码片段:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# From backend.py:265-336
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
def call_module(self, target, args, kwargs):
output = super().call_module(target, args, kwargs)
if target in self.compile_submod_names:
# 编译该 subgraph
compiled_graph = self.sglang_backend.compiler_manager.compile(
submod,
args,
self.inductor_config,
graph_index=index,
runtime_shape=None, # 通用 shape
)
# 创建后端
self.module.__dict__[target] = make_backend(
submod,
self.compile_config,
self.inductor_config,
self.graph_pool,
index,
len(self.compile_submod_names),
sym_shape_indices,
compiled_graph,
self.sglang_backend,
)
return output5.3 重要函数说明
| 函数 | 文件 | 行号 | 功能 |
|---|---|---|---|
split_graph() |
backend.py | 213 | 按指定 ops 切分计算图 |
make_compiler() |
backend.py | 29 | 创建编译器适配器 |
make_backend() |
backend.py | 38 | 创建 piecewise 后端 |
freeze_gc() |
piecewise_cuda_graph_runner.py | 67 | 冻结 GC 优化捕获 |
patch_model() |
piecewise_cuda_graph_runner.py | 96 | 为 CustomOp 打补丁 |
enable_piecewise_cuda_graph() |
piecewise_context_manager.py | 35 | 启用 piecewise 模式 |
set_forward_context() |
piecewise_context_manager.py | 83 | 设置执行上下文 |
5.4 配置参数
启动参数 (通过 server_args 传入):
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
--enable-piecewise-cuda-graph |
bool | False | 启用 piecewise CUDA Graph |
--piecewise-cuda-graph-tokens |
list[int] | [1, 2, 4, 8, 16, 32, 64, 128, 256] | 要捕获的 token sizes |
--piecewise-cuda-graph-compiler |
str | “eager” | 编译器类型 (eager/inductor) |
--enable-torch-compile-debug-mode |
bool | False | 启用调试模式 |
环境变量:
| 变量 | 说明 |
|---|---|
SGLANG_CACHE_DIR |
缓存目录路径 (默认 ~/.cache/sglang/) |
6. 总结与工程视角评价
6.1 优点
6.1.1 性能优化
- CUDA Graph 减少 kernel launch 开销
- CPU 到 GPU 的 kernel launch 延迟通常在 5-20μs
- 对于小 batch (1-16 tokens),单个 kernel 可能仅耗时 10-50μs
- CUDA Graph 将多次 launch 合并为一次,显著提升小 batch 性能
- Torch.compile 算子融合
- 自动优化算子融合策略
- 减少内存读写次数
- 提升计算密度
- 全局 memory pool + weak reference
- 多 graph 共享内存,降低峰值占用
- Weak reference 及时释放无用输出
6.1.2 灵活性
- 多 size 覆盖
- 支持任意配置的 capture sizes
- 运行时动态选择最合适的 graph
- 硬件抽象
- 统一接口支持 CUDA 和 NPU
- 易于扩展到其他硬件
- OP 级别隔离
- 灵活配置 split ops
- 支持复杂控制流的操作
6.1.3 工程实践
- 完整的调试支持
- Debug mode 验证输入地址一致性
- Graph 导出和可视化
- 详细的日志和计数器
- 缓存机制
- 编译结果持久化
- 减少启动时间
6.2 局限
6.2.1 功能冲突
| 冲突项 | 说明 |
|---|---|
| Pipeline Parallel (PP) | 不支持 PP,因为需要跨 stage 的状态传递 |
| 部分 torch.compile 功能 | 某些 dynamo 特性与 fullgraph 冲突 |
| Logprob 中间状态 | return_logprob 时需要序列稳定在中间状态 |
6.2.2 MOE 限制
不支持某些 MOE backend:
- DeepEP: 需要添加
moe_forward_piecewise_cuda_graph_impl到 split_ops - Mooncake: 同上
原因:动态路由导致控制流复杂,难以静态捕获
6.2.3 模型要求
- 仅支持标准 GQA attention
- 不支持非标准的 attention 变体
- 要求特定的 memory layout
6.2.4 内存开销
需要预先分配 memory pool:
- 对大模型/大 batch 有压力
- 需要调整
--mem-fraction-static参数 - 可能需要减少
--piecewise-cuda-graph-max-tokens
6.3 可扩展性
6.3.1 新增 split op
步骤:
- 在
CompilationConfig.split_ops添加 op 名称 - 确保 op 继承自
CustomOp - 实现
forward_native()方法
示例:
1
2
3
4
5
6
7
# compilation_config.py
self.split_ops = [
"sglang.unified_attention_with_output",
"sglang.gdn_with_output",
"sglang.inplace_all_reduce",
"sglang.my_custom_op", # 新增
]6.3.2 新增 backend
实现接口:
1
2
3
4
5
6
7
8
9
10
11
12
13
class MyPiecewiseBackend:
def __init__(self, graph, compile_config, inductor_config,
graph_pool, piecewise_compile_index, ...):
# 初始化
pass
def __call__(self, *args):
# 编译和捕获逻辑
pass
def capture(self):
# 硬件特定的捕获逻辑
pass6.3.3 自定义 compiler
实现 CompilerInterface:
1
2
3
4
5
6
7
8
9
10
from sglang.srt.compilation.compiler_interface import CompilerInterface
class MyCompiler(CompilerInterface):
def compile(self, graph, example_inputs, config):
# 自定义编译逻辑
pass
def load(self, handle, graph, example_inputs):
# 加载缓存的编译结果
pass6.4 最佳实践
Capture sizes 选择
- 根据实际 workload 的 token 分布选择
- 覆盖 80-90% 的请求
- 避免过多 sizes 导致内存压力
内存调优
1
2
3
4
5# 降低静态内存比例 --mem-fraction-static 0.8 # 减少最大 capture size --piecewise-cuda-graph-max-tokens 512编译器选择
eager: 更快启动,兼容性更好inductor: 更好性能,但编译时间更长
调试技巧
1
2
3
4
5# 启用调试模式 --enable-torch-compile-debug-mode # 检查导出的 graph ls ~/.cache/sglang/torch_compile_cache/
7. 文档附录
7.1 配置参数完整说明
| 参数 | 类型 | 默认值 | 范围 | 说明 |
|---|---|---|---|---|
enable_piecewise_cuda_graph |
bool | False | - | 启用 piecewise CUDA Graph |
piecewise_cuda_graph_tokens |
list[int] | [1,2,4,8,16,32,64,128,256] | 1-2048 | 要捕获的 token sizes |
piecewise_cuda_graph_compiler |
str | “eager” | eager/inductor | 编译器类型 |
enable_torch_compile_debug_mode |
bool | False | - | 启用调试模式 |
enable_cudagraph_gc |
bool | False | - | 捕获期间启用 GC |
mem_fraction_static |
float | 0.9 | 0.1-0.95 | 静态内存比例 |
7.2 关键函数调用序列图
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
启动阶段
ModelRunner.__init__
└─> init_piecewise_cuda_graphs
├─> PiecewiseCudaGraphRunner.__init__
│ ├─> CompilationConfig.__init__
│ ├─> set_global_graph_memory_pool
│ └─> install_torch_compiled
│ └─> SGLangBackend.__call__
│ ├─> split_graph
│ └─> PiecewiseCompileInterpreter.run
│ └─> CompilerManager.compile
├─> warmup_torch_compile (for each size)
└─> capture
└─> capture_one_batch_size (for each size)
└─> CUDAPiecewiseBackend.__call__
└─> torch.cuda.graph
推理阶段
ModelRunner.forward
└─> can_run
└─> replay
├─> replay_prepare
│ └─> bisect.bisect_left (查找 graph)
└─> model.forward
└─> CUDAPiecewiseBackend.__call__
└─> cudagraph.replay7.3 性能对比数据
理论分析:
| 场景 | Eager | Torch.compile | Piecewise (Eager) | Piecewise (Inductor) |
|---|---|---|---|---|
| 小 batch (1-4 tokens) | 1.0x | 1.2-1.5x | 1.5-2.0x | 1.8-2.5x |
| 中 batch (8-32 tokens) | 1.0x | 1.3-1.6x | 1.3-1.7x | 1.6-2.0x |
| 大 batch (64+ tokens) | 1.0x | 1.2-1.4x | 1.1-1.3x | 1.3-1.6x |
启动开销:
| 模式 | 启动时间 |
|---|---|
| Eager | ~1s |
| Torch.compile (eager) | ~10-30s |
| Torch.compile (inductor) | ~30-60s |
| Piecewise | +10-20% 上述时间 |
7.4 常见问题排查
问题 1: Capture 失败
错误信息:
1
RuntimeError: CUDA graph capture failed解决方案:
1
2
3
4
5
6
7
8
# 1. 降低静态内存比例
--mem-fraction-static 0.8
# 2. 减少最大 capture size
--piecewise-cuda-graph-max-tokens 512
# 3. 禁用 piecewise
--disable-piecewise-cuda-graph问题 2: 内存不足
错误信息:
1
CUDA out of memory解决方案:
1
2
3
4
5
6
7
8
# 1. 减少捕获 sizes
--piecewise-cuda-graph-tokens 1,2,4,8,16,32,64
# 2. 降低静态内存
--mem-fraction-static 0.7
# 3. 启用 GC
--enable-cudagraph-gc问题 3: 性能下降
可能原因:
- Capture sizes 不匹配 workload
- 频繁 miss cache
- 过大的 graph 导致内存压力
解决方案:
1
2
3
4
5
6
7
8
# 1. 调整 capture sizes
--piecewise-cuda-graph-tokens 16,32,64,128
# 2. 使用 inductor 编译器
--piecewise-cuda-graph-compiler inductor
# 3. 清理缓存
rm -rf ~/.cache/sglang/torch_compile_cache/7.5 参考资料
源代码:
- SGLang Repository: https://github.com/sgl-project/sglang
- 基于 vLLM v0.10.0 改编
相关文档:
- CUDA Graph: https://developer.nvidia.com/blog/cuda-graphs/
- Torch.compile: https://pytorch.org/tutorials/recipes/torch_compile_guide.html
- Torch FX: https://pytorch.org/docs/stable/fx.html
设计参考:
- vLLM Compilation System: https://github.com/vllm-project/vllm/blob/main/vllm/compilation/




