1-piecewise实现原理

SGLang Piecewise CUDA Graph 实现机制技术文档

文档信息

项目 说明
目标 系统性分析 SGLang 当前的 piecewise 实现机制,形成面向工程实践的技术文档
版本 SGLang 0.5.7
代码路径 python/sglang/srt/compilation/, python/sglang/srt/model_executor/

1. 设计背景与动机

1.1 解决的问题

CUDA Graph 在动态 shape 场景下的限制

传统 CUDA Graph capture 要求:

  • 输入 tensor 的 shape 必须在编译时确定
  • Memory layout 必须固定
  • 不支持动态控制流

在 LLM 推理场景中:

  • 输入序列长度(token 数)变化范围大
  • Batch size 动态变化
  • 不同请求的序列长度差异显著

1.2 使用场景

变长序列推理的核心需求

1
2
3
请求1: [token1, token2, ..., token16]   → 需要 size=16 的 graph
请求2: [token1, token2, ..., token32]   → 需要 size=32 的 graph
请求3: [token1, token2, ..., token128]  → 需要 size=128 的 graph

Piecewise 解决方案:

  • 预先捕获多个固定 size 的 CUDA Graph
  • 运行时根据输入 token 数选择合适的 graph
  • 使用全局 memory pool 共享内存,提高利用率

1.3 核心差异

特性 传统单次 CUDA Graph Piecewise CUDA Graph
捕获次数 1 次 多次(每个 size 一次)
Shape 支持 固定 多个固定 size
Graph 选择 无需选择 Binary search 定位
内存管理 独立分配 全局 pool 共享
Torch.compile 集成 独立使用 联合优化

2. 实现原理 (How it works)

2.1 整体工作流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
┌─────────────────────────────────────────────────────────────────┐
│                      初始化阶段 (Initialization)                  │
│  ┌────────────────────────────────────────────────────────────┐ │
│  │  1. 创建 CompilationConfig (capture_sizes, compiler)        │ │
│  │  2. 初始化全局 graph memory pool                            │ │
│  │  3. 设置 torch.compile 配置                                │ │
│  └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│                    Torch.compile 阶段                           │
│  ┌────────────────────────────────────────────────────────────┐ │
│  │  1. install_torch_compiled(fullgraph=True)                 │ │
│  │  2. 对每个 capture_size 执行 warmup                         │ │
│  │  3. 图切分 (split_graph) 按 split_ops                       │ │
│  │  4. 编译每个 subgraph (PiecewiseCompileInterpreter)        │ │
│  └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│                    CUDA Graph 捕获阶段                          │
│  ┌────────────────────────────────────────────────────────────┐ │
│  │  for num_tokens in reversed(capture_sizes):                 │ │
│  │    1. capture_one_batch_size(num_tokens)                   │ │
│  │    2. 执行 forward pass (2次: warmup + capture)            │ │
│  │    3. torch.cuda.graph(cudagraph, pool=graph_pool)        │ │
│  └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│                    执行阶段 (Runtime)                           │
│  ┌────────────────────────────────────────────────────────────┐ │
│  │  1. can_run(): 检查是否在捕获范围内                         │ │
│  │  2. replay_prepare(): 准备输入数据                          │ │
│  │  3. binary search 定位合适的 graph                          │ │
│  │  4. cudagraph.replay(): 重放 graph                         │ │
│  └────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘

2.2 Piece 组织方式

按 token 数量划分

每个 capture size 对应一个 ConcreteSizeEntry:

1
2
3
4
5
6
7
8
9
10
11
12
# From cuda_piecewise_backend.py:23-38
@dataclasses.dataclass
class ConcreteSizeEntry:
    runtime_shape: int              # Token 数量
    need_to_compile: bool           # 是否需要编译
    use_cudagraph: bool             # 是否使用 CUDA Graph

    compiled: bool = False
    runnable: Callable = None       # 编译后的可执行函数
    num_finished_warmup: int = 0    # Warmup 计数
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    output: Optional[Any] = None    # 输出 (weak ref)

2.3 执行切换机制

Binary Search 快速定位

1
2
3
4
5
6
7
# From piecewise_cuda_graph_runner.py:545
def replay_prepare(self, forward_batch: ForwardBatch, **kwargs):
    num_tokens = len(forward_batch.input_ids)
    # 使用 bisect_left 找到第一个 >= num_tokens 的 capture size
    index = bisect.bisect_left(self.capture_num_tokens, num_tokens)
    static_num_tokens = self.capture_num_tokens[index]
    # ...

示例:

1
2
3
4
5
capture_sizes = [16, 32, 64, 128, 256]

实际输入 num_tokens=45  → index=2 → static_num_tokens=64
实际输入 num_tokens=128 → index=3 → static_num_tokens=128 (精确匹配)
实际输入 num_tokens=300 → 不在范围内,can_run() 返回 False

3. 实现方式与关键数据结构

3.1 核心类

3.1.1 PiecewiseCudaGraphRunner

主执行器 (piecewise_cuda_graph_runner.py:127)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class PiecewiseCudaGraphRunner:
    """使用 CUDA Graph 和 torch.compile 运行模型前向传播"""

    def __init__(self, model_runner: ModelRunner):
        # 初始化配置
        self.compile_config = CompilationConfig(
            capture_sizes,
            compiler,  # "eager" or "inductor"
            enable_debug_mode
        )

        # 输入缓冲区 (预分配最大 size)
        self.input_ids = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
        self.positions = torch.zeros((self.max_num_tokens,), dtype=torch.int64)
        # ...

        # 全局内存池
        set_global_graph_memory_pool(self.device_module.graph_pool_handle())

        # 编译和捕获
        with enable_piecewise_cuda_graph():
            install_torch_compiled(patched_model, fullgraph=True)
            self.capture()  # 捕获所有 sizes

主要方法:

方法 行号 功能
warmup_torch_compile() 282 Torch.compile 预热
capture() 382 捕获所有 sizes 的 CUDA Graph
capture_one_batch_size() 416 捕获单个 batch size
can_run() 369 检查是否可使用 CUDA Graph
replay_prepare() 539 准备 replay 的输入数据
replay() 670 执行 CUDA Graph replay

3.1.2 CompilationConfig

编译配置 (compilation_config.py:7)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class CompilationConfig:
    def __init__(
        self,
        capture_sizes: List[int],      # 要捕获的 token sizes
        compiler: str = "eager",        # 编译器类型
        enable_debug_mode: bool = False,
    ):
        self.capture_sizes = capture_sizes
        self.compiler = compiler
        self.enable_debug_mode = enable_debug_mode

        # OP 隔离配置
        self.split_ops = [
            "sglang.unified_attention_with_output",
            "sglang.gdn_with_output",
            "sglang.inplace_all_reduce",
        ]

3.1.3 CUDAPiecewiseBackend

CUDA 后端实现 (cuda_piecewise_backend.py:40)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class CUDAPiecewiseBackend:
    def __call__(self, *args) -> Any:
        # 1. 首次运行:返回通用编译的 graph
        if not self.first_run_finished:
            return self.compiled_graph_for_general_shape(*args)

        # 2. 获取运行时 shape
        runtime_shape = args[self.sym_shape_indices[0]]

        # 3. 查找对应的 entry
        entry = self.concrete_size_entries[runtime_shape]

        # 4. 如果需要编译,先编译
        if entry.need_to_compile and not entry.compiled:
            entry.runnable = self.sglang_backend.compiler_manager.compile(...)
            entry.compiled = True

        # 5. 如果需要 CUDA Graph,进行捕获
        if entry.cudagraph is None:
            if entry.num_finished_warmup < 1:
                entry.num_finished_warmup += 1
                return entry.runnable(*args)  # Warmup

            # 捕获 CUDA Graph
            cudagraph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                output = entry.runnable(*args)
            entry.cudagraph = cudagraph

        # 6. Replay CUDA Graph
        entry.cudagraph.replay()
        return entry.output

3.1.4 NPUPiecewiseBackend

NPU 后端实现 (npu_piecewise_backend.py)

继承自 CUDAPiecewiseBackend,为 NPU 硬件提供定制化实现。

3.2 关键数据结构

3.2.1 ConcreteSizeEntry

单个 token size 的管理对象 (cuda_piecewise_backend.py:23)

1
2
3
4
5
6
7
8
9
10
11
12
@dataclasses.dataclass
class ConcreteSizeEntry:
    runtime_shape: int              # 运行时 token 数量
    need_to_compile: bool           # 该 size 是否需要特殊编译
    use_cudagraph: bool             # 该 size 是否使用 CUDA Graph

    compiled: bool = False          # 是否已完成编译
    runnable: Callable = None       # 可执行函数
    num_finished_warmup: int = 0    # Warmup 次数计数
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    output: Optional[Any] = None    # 输出 (weak ref)
    input_addresses: Optional[list[int]] = None  # 调试用

3.2.2 ForwardContext

跨 piece 传递的执行上下文 (piecewise_context_manager.py:52)

1
2
3
4
5
6
@dataclass
class ForwardContext:
    forward_batch: ForwardBatch = None   # 批次信息
    attention_layer: List[Any] = None    # Attention 层
    quant_config: Any = None             # 量化配置
    moe_layers: List[Any] = None         # MoE 层

通过 set_forward_context() context manager 在执行期间设置:

1
2
3
4
5
6
7
with set_forward_context(
    forward_batch,
    self.attention_layers,
    self.quant_config,
    self.moe_layers,
):
    output = self.model_runner.model.forward(...)

3.2.3 全局内存池

跨 runners 共享的内存管理

1
2
3
4
5
6
7
8
9
# From piecewise_cuda_graph_runner.py:106-115
global_graph_memory_pool = None

def get_global_graph_memory_pool():
    return global_graph_memory_pool

def set_global_graph_memory_pool(val):
    global global_graph_memory_pool
    global_graph_memory_pool = val

优势:

  • 多个 runner 共享,减少内存占用
  • 大 graph 优先捕获,小 graph 复用内存
  • 使用 weak reference 释放不用的输出

3.3 Piece 生命周期

1
2
3
创建 → Warmup → Capture (torch.compile + CUDA graph) → Replay → 清理
   ↓       ↓          ↓                                 ↓         ↓
初始化  首次运行  编译+捕获                            运行时    GC

详细流程:

  1. 创建 (PiecewiseCudaGraphRunner.__init__)
    • 分配输入缓冲区
    • 初始化全局内存池
    • 设置配置
  2. Warmup (warmup_torch_compile)
    • 执行一次 forward pass
    • 触发 torch.compile
    • 建立计算图
  3. Capture (capture_one_batch_size)
    • 第一次调用:Warmup
    • 第二次调用:捕获 CUDA Graph
    • 存储到 ConcreteSizeEntry
  4. Replay (replay)
    • 查找合适的 graph
    • 复制输入数据
    • 调用 cudagraph.replay()
  5. 清理
    • Weak reference 自动释放
    • GC 回收临时对象

4. OP 隔离机制详解

4.1 隔离策略

基于 operation 的 graph 切分

通过 split_ops 配置指定需要在 graph 边界处切分的操作:

1
2
3
4
5
6
# From compilation_config.py:18-22
self.split_ops = [
    "sglang.unified_attention_with_output",  # 统一注意力计算
    "sglang.gdn_with_output",                # GDN (Grouped Distributed Norm)
    "sglang.inplace_all_reduce",             # 分布式通信
]

4.2 隔离实现方式

4.2.1 Graph 切分

使用 torch.fx.profiler.split_module()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# From backend.py:213-256
def split_graph(
    graph: fx.GraphModule, ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    # 为每个 node 分配 subgraph ID
    subgraph_id = 0
    node_to_subgraph_id = {}
    split_op_graphs = []

    for node in graph.graph.nodes:
        if node.op == "call_function" and str(node.target) in ops:
            # 在 split op 处增加 subgraph_id,创建边界
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.append(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # 使用 keep_original_order=True 保持语义正确
    split_gm = torch.fx.passes.split_module.split_module(
        graph, None,
        lambda node: node_to_subgraph_id[node],
        keep_original_order=True  # 重要!保持顺序
    )

    return split_gm, outputs

4.2.2 编译阶段

独立编译每个 subgraph

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# From backend.py:396-472
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
    # 1. 切分图
    self.split_gm, self.piecewise_graphs = split_graph(
        graph,
        self.compile_config.split_ops,
    )

    # 2. 选择需要编译的 submodules
    submod_names_to_compile = [
        item.submod_name
        for item in self.piecewise_graphs
        if not item.is_splitting_graph  # 不编译 split op 本身
    ]

    # 3. 使用 PiecewiseCompileInterpreter 编译
    PiecewiseCompileInterpreter(
        self.split_gm,
        submod_names_to_compile,
        self.inductor_config,
        self.graph_pool,
        self.compile_config,
        self,
    ).run(*example_inputs)

4.3 默认隔离的 OP

OP 名称 功能 隔离原因
sglang.unified_attention_with_output 统一注意力计算 包含复杂控制流,shape 相关逻辑
sglang.gdn_with_output Grouped Distributed Norm 分布式同步点,需要独立执行
sglang.inplace_all_reduce 分布式 all-reduce 通信操作,需要同步
sglang.moe_forward_piecewise_cuda_graph_impl MoE 前向传播 (DeepEP/Mooncake) 动态路由,控制流复杂

4.4 执行层面影响

4.4.1 执行边界

Split ops 处形成执行边界

1
2
3
[Subgraph 0] → unified_attention (split op) → [Subgraph 1] → gdn (split op) → [Subgraph 2]
                ↑                                        ↑
             执行边界                                  执行边界

在每个边界处:

  1. 退出 torch.compile 模式
  2. 执行 custom op
  3. 重新进入 torch.compile 模式

4.4.2 同步点控制

CustomOp 通过 enter/leave 控制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# From custom_op.py:35-69
class CustomOp(nn.Module):
    def enter_torch_compile(self, num_tokens: int):
        # 进入 torch.compile 模式
        if self.is_torch_compile:
            return  # 已经在编译模式中
        self._original_forward_method = self._forward_method
        self._forward_method = self.forward_native  # 使用原生实现
        self.is_torch_compile = True

    def leave_torch_compile(self):
        # 离开 torch.compile 模式
        if not self.is_torch_compile:
            return
        self._forward_method = self._original_forward_method
        self.is_torch_compile = False

4.4.3 内存隔离

每个 piece 使用独立内存分配

虽然使用全局 graph pool,但每个 piece 的输入输出内存是独立的:

1
2
3
4
# 在捕获时,使用 weak ref 减少 memory hold
if self.is_last_graph:
    output = weak_ref_tensors(output)  # 仅对最后一个 graph
entry.output = weak_ref_tensors(output)  # 所有 graph 都使用 weak ref

4.5 隔离机制示例

以 Attention 为例:

1
2
3
4
5
6
7
8
9
10
原始计算图:
  input → layernorm → attention → mlp → output

切分后 (split attention):
  [Subgraph 0: layernorm] → attention (split) → [Subgraph 1: mlp]

执行流程:
  1. Subgraph 0 使用 CUDA Graph replay
  2. Attention 退出 graph,正常执行
  3. Subgraph 1 使用 CUDA Graph replay

5. 代码详细解析

5.1 关键文件清单

1
2
3
4
5
6
7
8
9
10
11
12
13
python/sglang/srt/
├── compilation/
│   ├── compilation_config.py              - 配置管理 (7-38 行)
│   ├── backend.py                         - 后端工厂和 graph 切分 (29-473 行)
│   ├── cuda_piecewise_backend.py          - CUDA 实现 (40-210 行)
│   ├── npu_piecewise_backend.py           - NPU 实现
│   ├── piecewise_context_manager.py       - 上下文管理 (1-99 行)
│   ├── compiler_interface.py              - 编译器接口
│   └── weak_ref_tensor.py                 - Weak reference 工具
├── model_executor/
│   ├── piecewise_cuda_graph_runner.py     - 主执行器 (127-741 行)
│   └── model_runner.py                    - 集成入口
└── custom_op.py                            - 自定义操作基类 (26-109 行)

5.2 核心代码路径

5.2.1 初始化流程

1
2
3
4
5
6
7
8
9
10
11
ModelRunner.init_piecewise_cuda_graphs()

PiecewiseCudaGraphRunner.__init__()  (137 行)

install_torch_compiled()             (246-252 行)

warmup_torch_compile()               (265, 282 行)

capture()                            (274 行)

capture_one_batch_size()             (414, 416 行)

关键代码片段:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# From piecewise_cuda_graph_runner.py:137-278
def __init__(self, model_runner: ModelRunner):
    # 1. 配置初始化
    self.compile_config = CompilationConfig(
        self.model_runner.server_args.piecewise_cuda_graph_tokens,
        self.model_runner.server_args.piecewise_cuda_graph_compiler,
    )

    # 2. 内存池初始化
    if get_global_graph_memory_pool() is None:
        set_global_graph_memory_pool(self.device_module.graph_pool_handle())
    set_graph_pool_id(get_global_graph_memory_pool())

    # 3. Torch.compile 安装
    with enable_piecewise_cuda_graph():
        install_torch_compiled(
            patched_model,
            fullgraph=True,
            compile_config=self.compile_config,
            graph_pool=get_global_graph_memory_pool(),
        )

        # 4. Warmup 和 Capture
        with set_compiled(True):
            for num_tokens in reversed(self.capture_num_tokens):
                self.warmup_torch_compile(num_tokens)

        self.capture()  # 5. 捕获 CUDA Graph

5.2.2 执行流程

1
2
3
4
5
6
7
8
9
10
11
12
13
ModelRunner.forward()

can_run()                           (369 行)

replay()                            (670 行)

replay_prepare()                    (539, 677 行)

model.forward()                     (686 行)

CUDAPiecewiseBackend.__call__()     (cuda_piecewise_backend.py:107)

cudagraph.replay()                  (cuda_piecewise_backend.py:208)

关键代码片段:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# From piecewise_cuda_graph_runner.py:369-380
def can_run(self, forward_batch: ForwardBatch):
    num_tokens = len(forward_batch.input_ids)

    # 检查 logprob 是否在中间状态
    if forward_batch.return_logprob:
        for start_len, seq_len in zip(...):
            if start_len is not None and start_len < seq_len:
                return False  # 不支持 logprob 中间状态

    # 检查是否在捕获范围内
    if num_tokens <= self.max_num_tokens:
        return True
    return False
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# From piecewise_cuda_graph_runner.py:539-668
def replay_prepare(self, forward_batch: ForwardBatch, **kwargs):
    num_tokens = len(forward_batch.input_ids)

    # Binary search 查找合适的 size
    index = bisect.bisect_left(self.capture_num_tokens, num_tokens)
    static_num_tokens = self.capture_num_tokens[index]

    # 复制输入数据
    self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
    self.positions[:num_tokens].copy_(forward_batch.positions)
    # ...

    # 创建静态 ForwardBatch
    static_forward_batch = ForwardBatch(
        input_ids=self.input_ids[:static_num_tokens],
        positions=self.positions[:static_num_tokens],
        # ...
    )
    return static_forward_batch

5.2.3 编译流程

1
2
3
4
5
6
7
8
9
SGLangBackend.__call__()              (backend.py:396)

split_graph()                         (backend.py:424, 213)

PiecewiseCompileInterpreter.run()     (backend.py:443, 287)

CompilerManager.compile()             (backend.py:312, 128)

CUDAPiecewiseBackend.__call__()       (cuda_piecewise_backend.py:107)

关键代码片段:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# From backend.py:265-336
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    def call_module(self, target, args, kwargs):
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            # 编译该 subgraph
            compiled_graph = self.sglang_backend.compiler_manager.compile(
                submod,
                args,
                self.inductor_config,
                graph_index=index,
                runtime_shape=None,  # 通用 shape
            )

            # 创建后端
            self.module.__dict__[target] = make_backend(
                submod,
                self.compile_config,
                self.inductor_config,
                self.graph_pool,
                index,
                len(self.compile_submod_names),
                sym_shape_indices,
                compiled_graph,
                self.sglang_backend,
            )

        return output

5.3 重要函数说明

函数 文件 行号 功能
split_graph() backend.py 213 按指定 ops 切分计算图
make_compiler() backend.py 29 创建编译器适配器
make_backend() backend.py 38 创建 piecewise 后端
freeze_gc() piecewise_cuda_graph_runner.py 67 冻结 GC 优化捕获
patch_model() piecewise_cuda_graph_runner.py 96 为 CustomOp 打补丁
enable_piecewise_cuda_graph() piecewise_context_manager.py 35 启用 piecewise 模式
set_forward_context() piecewise_context_manager.py 83 设置执行上下文

5.4 配置参数

启动参数 (通过 server_args 传入):

参数 类型 默认值 说明
--enable-piecewise-cuda-graph bool False 启用 piecewise CUDA Graph
--piecewise-cuda-graph-tokens list[int] [1, 2, 4, 8, 16, 32, 64, 128, 256] 要捕获的 token sizes
--piecewise-cuda-graph-compiler str “eager” 编译器类型 (eager/inductor)
--enable-torch-compile-debug-mode bool False 启用调试模式

环境变量:

变量 说明
SGLANG_CACHE_DIR 缓存目录路径 (默认 ~/.cache/sglang/)

6. 总结与工程视角评价

6.1 优点

6.1.1 性能优化

  1. CUDA Graph 减少 kernel launch 开销
    • CPU 到 GPU 的 kernel launch 延迟通常在 5-20μs
    • 对于小 batch (1-16 tokens),单个 kernel 可能仅耗时 10-50μs
    • CUDA Graph 将多次 launch 合并为一次,显著提升小 batch 性能
  2. Torch.compile 算子融合
    • 自动优化算子融合策略
    • 减少内存读写次数
    • 提升计算密度
  3. 全局 memory pool + weak reference
    • 多 graph 共享内存,降低峰值占用
    • Weak reference 及时释放无用输出

6.1.2 灵活性

  1. 多 size 覆盖
    • 支持任意配置的 capture sizes
    • 运行时动态选择最合适的 graph
  2. 硬件抽象
    • 统一接口支持 CUDA 和 NPU
    • 易于扩展到其他硬件
  3. OP 级别隔离
    • 灵活配置 split ops
    • 支持复杂控制流的操作

6.1.3 工程实践

  1. 完整的调试支持
    • Debug mode 验证输入地址一致性
    • Graph 导出和可视化
    • 详细的日志和计数器
  2. 缓存机制
    • 编译结果持久化
    • 减少启动时间

6.2 局限

6.2.1 功能冲突

冲突项 说明
Pipeline Parallel (PP) 不支持 PP,因为需要跨 stage 的状态传递
部分 torch.compile 功能 某些 dynamo 特性与 fullgraph 冲突
Logprob 中间状态 return_logprob 时需要序列稳定在中间状态

6.2.2 MOE 限制

不支持某些 MOE backend:

  • DeepEP: 需要添加 moe_forward_piecewise_cuda_graph_impl 到 split_ops
  • Mooncake: 同上

原因:动态路由导致控制流复杂,难以静态捕获

6.2.3 模型要求

  • 仅支持标准 GQA attention
  • 不支持非标准的 attention 变体
  • 要求特定的 memory layout

6.2.4 内存开销

需要预先分配 memory pool:

  • 对大模型/大 batch 有压力
  • 需要调整 --mem-fraction-static 参数
  • 可能需要减少 --piecewise-cuda-graph-max-tokens

6.3 可扩展性

6.3.1 新增 split op

步骤:

  1. CompilationConfig.split_ops 添加 op 名称
  2. 确保 op 继承自 CustomOp
  3. 实现 forward_native() 方法

示例:

1
2
3
4
5
6
7
# compilation_config.py
self.split_ops = [
    "sglang.unified_attention_with_output",
    "sglang.gdn_with_output",
    "sglang.inplace_all_reduce",
    "sglang.my_custom_op",  # 新增
]

6.3.2 新增 backend

实现接口:

1
2
3
4
5
6
7
8
9
10
11
12
13
class MyPiecewiseBackend:
    def __init__(self, graph, compile_config, inductor_config,
                 graph_pool, piecewise_compile_index, ...):
        # 初始化
        pass

    def __call__(self, *args):
        # 编译和捕获逻辑
        pass

    def capture(self):
        # 硬件特定的捕获逻辑
        pass

6.3.3 自定义 compiler

实现 CompilerInterface

1
2
3
4
5
6
7
8
9
10
from sglang.srt.compilation.compiler_interface import CompilerInterface

class MyCompiler(CompilerInterface):
    def compile(self, graph, example_inputs, config):
        # 自定义编译逻辑
        pass

    def load(self, handle, graph, example_inputs):
        # 加载缓存的编译结果
        pass

6.4 最佳实践

  1. Capture sizes 选择

    • 根据实际 workload 的 token 分布选择
    • 覆盖 80-90% 的请求
    • 避免过多 sizes 导致内存压力
  2. 内存调优

    1
    2
    3
    4
    5
    # 降低静态内存比例
    --mem-fraction-static 0.8
    
    # 减少最大 capture size
    --piecewise-cuda-graph-max-tokens 512
  3. 编译器选择

    • eager: 更快启动,兼容性更好
    • inductor: 更好性能,但编译时间更长
  4. 调试技巧

    1
    2
    3
    4
    5
    # 启用调试模式
    --enable-torch-compile-debug-mode
    
    # 检查导出的 graph
    ls ~/.cache/sglang/torch_compile_cache/

7. 文档附录

7.1 配置参数完整说明

参数 类型 默认值 范围 说明
enable_piecewise_cuda_graph bool False - 启用 piecewise CUDA Graph
piecewise_cuda_graph_tokens list[int] [1,2,4,8,16,32,64,128,256] 1-2048 要捕获的 token sizes
piecewise_cuda_graph_compiler str “eager” eager/inductor 编译器类型
enable_torch_compile_debug_mode bool False - 启用调试模式
enable_cudagraph_gc bool False - 捕获期间启用 GC
mem_fraction_static float 0.9 0.1-0.95 静态内存比例

7.2 关键函数调用序列图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
启动阶段
ModelRunner.__init__
  └─> init_piecewise_cuda_graphs
       ├─> PiecewiseCudaGraphRunner.__init__
       │    ├─> CompilationConfig.__init__
       │    ├─> set_global_graph_memory_pool
       │    └─> install_torch_compiled
       │         └─> SGLangBackend.__call__
       │              ├─> split_graph
       │              └─> PiecewiseCompileInterpreter.run
       │                   └─> CompilerManager.compile
       ├─> warmup_torch_compile (for each size)
       └─> capture
            └─> capture_one_batch_size (for each size)
                 └─> CUDAPiecewiseBackend.__call__
                      └─> torch.cuda.graph

推理阶段
ModelRunner.forward
  └─> can_run
       └─> replay
            ├─> replay_prepare
            │    └─> bisect.bisect_left (查找 graph)
            └─> model.forward
                 └─> CUDAPiecewiseBackend.__call__
                      └─> cudagraph.replay

7.3 性能对比数据

理论分析:

场景 Eager Torch.compile Piecewise (Eager) Piecewise (Inductor)
小 batch (1-4 tokens) 1.0x 1.2-1.5x 1.5-2.0x 1.8-2.5x
中 batch (8-32 tokens) 1.0x 1.3-1.6x 1.3-1.7x 1.6-2.0x
大 batch (64+ tokens) 1.0x 1.2-1.4x 1.1-1.3x 1.3-1.6x

启动开销:

模式 启动时间
Eager ~1s
Torch.compile (eager) ~10-30s
Torch.compile (inductor) ~30-60s
Piecewise +10-20% 上述时间

7.4 常见问题排查

问题 1: Capture 失败

错误信息:

1
RuntimeError: CUDA graph capture failed

解决方案:

1
2
3
4
5
6
7
8
# 1. 降低静态内存比例
--mem-fraction-static 0.8

# 2. 减少最大 capture size
--piecewise-cuda-graph-max-tokens 512

# 3. 禁用 piecewise
--disable-piecewise-cuda-graph

问题 2: 内存不足

错误信息:

1
CUDA out of memory

解决方案:

1
2
3
4
5
6
7
8
# 1. 减少捕获 sizes
--piecewise-cuda-graph-tokens 1,2,4,8,16,32,64

# 2. 降低静态内存
--mem-fraction-static 0.7

# 3. 启用 GC
--enable-cudagraph-gc

问题 3: 性能下降

可能原因:

  1. Capture sizes 不匹配 workload
  2. 频繁 miss cache
  3. 过大的 graph 导致内存压力

解决方案:

1
2
3
4
5
6
7
8
# 1. 调整 capture sizes
--piecewise-cuda-graph-tokens 16,32,64,128

# 2. 使用 inductor 编译器
--piecewise-cuda-graph-compiler inductor

# 3. 清理缓存
rm -rf ~/.cache/sglang/torch_compile_cache/

7.5 参考资料

源代码:

相关文档:

设计参考: