Experimental Method for Finding the Optimal PD Disaggregation Ratio

Overview

In a PD (Prefill-Decode) disaggregated architecture, Prefill and Decode performance are fully decoupled: the TTFT (Time To First Token) and TPOT (Time Per Output Token) metrics are determined entirely by the input/output throughput and the specific machine ratio. This article describes how to benchmark Prefill and Decode separately and determine the best performance ratio between them.

Testing Prefill Performance

How the Test Works

Testing Prefill performance is relatively simple: Decode throughput is usually far higher than Prefill throughput, so a single ordinary Decode node is all you need.

Test Procedure

  1. Set up a Decode node: launch one ordinary Decode node
  2. Configure the bench parameters
    • Set the output length to 1
    • This still covers the full Prefill computation plus the PD (KVCache) transfer
    • Pick a reasonable concurrency
  3. Example configuration
    • Input: 8k tokens
    • Output: 1 token
    • Concurrency: 128

From such a run you can compute the maximum Prefill RPS (Requests Per Second).
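
As a rough illustration of that calculation, the sketch below derives the Prefill RPS from such a run; the latency value is an assumed placeholder, not a measured result.

# Estimate the maximum Prefill RPS from a bench run with output length 1.
# All numbers are illustrative assumptions, not measurements.
concurrency = 128          # concurrent requests during the bench
avg_latency_s = 12.8       # assumed average end-to-end latency per request

# With output length 1, each request is essentially one prefill pass plus the KV transfer,
# so at steady state: RPS ≈ concurrency / average request latency.
max_prefill_rps = concurrency / avg_latency_s
print(f"Max Prefill RPS ≈ {max_prefill_rps:.1f}")   # ≈ 10.0 with these assumptions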

Dimensions to Sweep

Run multiple test groups to find the best configuration:

  • Different GPU counts: 8 GPUs, 16 GPUs, etc.
  • Different DP:TP ratios
  • Whether DeepSpeed is enabled, etc.

These tests give you the best performance of a single Prefill node while meeting the SLO (Service Level Objective), together with the corresponding maximum RPS.

Testing Decode Performance

Why It Is Harder

Testing Decode performance is much trickier, mainly because Decode is so fast that even a large number of Prefill nodes may not be able to saturate it.

Solution: Forced Slowdown

Following sglang/issues/6017, you can make the Decode node sleep for a configurable number of seconds on every forward pass, forcibly slowing down Decode inference.

Key point: only the inference (forward) step is blocked; everything else (e.g. receiving KVCache) runs at full speed.

Step-by-Step Procedure

Step 1: Launch the PD Nodes and Minilb

# After the PD nodes are up, start minilb
# minilb is preferred because it is simple: no fake requests and no health checks

Step 2: Add a Delay to the Decode Node

curl -H "Content-Type: application/json" \
  -d '{"forward_sleep_time": 90.0}' \
  -X POST "http://YOUR_FIRST_DECODE_NODE_IP:30000/slow_down"

Step 3: Send a Large Number of Requests

# Send a large number of requests to the Prefill nodes
# Wait 5-10 minutes, until the requests piled up on the Decode node exceed --max-running-requests

Step 4: Remove the Delay and Observe

curl -H "Content-Type: application/json" \
  -d '{"forward_sleep_time": null}' \
  -X POST "http://YOUR_FIRST_DECODE_NODE_IP:30000/slow_down"

Check the Decode logs to confirm that inference is running at the expected concurrency.

Log Analysis

Sample log:

[2025-10-10 08:50:41 DP7 TP7 EP7] Decode batch. #running-req: 64, #token: 508158, token usage: 0.63, pre-allocated usage: 0.00, #retracted-req: 0, cuda graph: True, gen throughput (token/s): 2231.5, #queue-req: 0,

Two Core Metrics to Watch

  1. #running-req: must be the concurrency we intend to test
  2. gen throughput (token/s): the actual Decode throughput during inference

With the throughput number you can compute (worked example after this list):

  • TPOT (Time Per Output Token)
  • Per-GPU Decode performance at this concurrency (throughput ÷ number of GPUs)
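
For example, plugging in the numbers from the sample log above (the GPU count is an assumption for illustration, not taken from the log):

# Derive TPOT and per-GPU Decode throughput from the log fields above.
running_req = 64            # "#running-req" from the log
gen_throughput = 2231.5     # "gen throughput (token/s)" from the log
num_gpus = 8                # assumed GPU count of this Decode instance

# Each request produces tokens at gen_throughput / running_req tokens per second,
# so the time per output token is the inverse of that rate.
tpot_s = running_req / gen_throughput
per_gpu_throughput = gen_throughput / num_gpus

print(f"TPOT ≈ {tpot_s * 1000:.1f} ms")                                   # ≈ 28.7 ms
print(f"Per-GPU Decode throughput ≈ {per_gpu_throughput:.0f} token/s")    # ≈ 279 token/s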

Capacity Planning Notes

The following must be set at launch:

  • --max-running-requests: guarantees the node runs at the intended concurrency
  • Total KVCache capacity: must be large enough to hold the test concurrency

Example: with 8k input and 2k output, each request needs 10k tokens of KVCache; to test a concurrency of 128 you need at least 128 × 10k = 1.28M tokens of KVCache capacity.
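
A quick sanity check of this arithmetic in code (same numbers as the example above):

# KVCache capacity needed for the Decode test.
input_len = 8_000            # ~8k input tokens
output_len = 2_000           # ~2k output tokens
target_concurrency = 128

tokens_per_request = input_len + output_len              # 10k tokens of KVCache per request
required_kv_tokens = target_concurrency * tokens_per_request
print(f"Required KVCache capacity ≥ {required_kv_tokens / 1e6:.2f}M tokens")   # 1.28M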

These tests give you the best Decode configuration, its best performance under the SLO, and the corresponding maximum RPS.

End-to-End Testing

After the independent Prefill and Decode performance tests, run an overall end-to-end test.

How to Compute the Ratio

Assume that, while meeting the SLO:

  • Prefill: 16 GPUs reach RPS = 5
  • Decode: 16 GPUs reach RPS = 10

Then the optimal PD ratio is:

10 : 5 = 2 : 1

So you need a 2P1D configuration (2 Prefill nodes, 1 Decode node), which can reach RPS = 10.
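
The same ratio computation written out as a small sketch (the per-node RPS values are the assumed numbers from this example):

from fractions import Fraction

# Assumed per-node capability under the SLO, taken from the independent tests above.
prefill_rps_per_node = 5.0    # 16 GPUs per Prefill node
decode_rps_per_node = 10.0    # 16 GPUs per Decode node

# Balance the two stages: num_prefill * prefill_rps == num_decode * decode_rps,
# so num_prefill : num_decode = decode_rps : prefill_rps.
ratio = Fraction(decode_rps_per_node) / Fraction(prefill_rps_per_node)
num_prefill, num_decode = ratio.numerator, ratio.denominator

total_rps = min(num_prefill * prefill_rps_per_node, num_decode * decode_rps_per_node)
print(f"{num_prefill}P{num_decode}D, expected RPS ≈ {total_rps}")   # 2P1D, RPS ≈ 10.0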

Test Steps

  1. Launch the service: start 2P1D plus the Router
  2. Run the end-to-end benchmark

bench_serving \
  --concurrency <concurrency> \
  --request-rate 0.1 \
  # ... other parameters

Key Parameters

  • --concurrency: controls the concurrency, to mimic the real production workload (see the sizing sketch below)
  • --request-rate 0.1: i.e. 1/RPS, controls the request rate
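
A rough way to size --concurrency for a target RPS is Little's law (concurrency ≈ RPS × mean request latency); the numbers below are illustrative assumptions, not measurements.

# Rough --concurrency sizing for the end-to-end run (illustrative assumptions only).
target_rps = 10.0           # the RPS the 2P1D setup is expected to sustain
ttft_s = 5.0                # assumed mean TTFT under load
tpot_s = 0.03               # assumed TPOT (30 ms)
output_len = 2048           # output tokens per request

mean_latency_s = ttft_s + output_len * tpot_s    # ≈ 66.4 s per request
concurrency = target_rps * mean_latency_s        # Little's law
print(f"--concurrency ≈ {concurrency:.0f}")      # ≈ 664 with these assumptions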

Common Issues and Caveats

PD Transfer Timeout

Problem: requests may fail under high concurrency

Fix: increase the PD transfer timeout parameter

System File Descriptor Limit

Problem: network concurrency is limited

Fix:

ulimit -n 65536

Choosing the Test Tool

Recommendations

  • Prefer minilb when testing Decode
  • Reason: it has no fake requests and no health checks

Skipping Warmup

If you use bench_serving to test Decode:

bench_serving --skip-warmup

Reason: since inference is deliberately delayed, the warmup request would wait a very long time, and warmup is not needed here.

Minilb Concurrency Configuration

The minilb in the upstream repository requires source changes to adjust its HTTP concurrency.

Note: minilb has since been moved into the gateway and is harder to use directly, so the modified standalone version below is recommended.

Modified minilb code (two files: launch_lb.py and mini_lb.py)

launch_lb.py:

import argparse
import dataclasses


@dataclasses.dataclass
class LBArgs:
    rust_lb: bool = False
    host: str = "0.0.0.0"
    port: int = 8000
    policy: str = "random"
    prefill_infos: list = dataclasses.field(default_factory=list)
    decode_infos: list = dataclasses.field(default_factory=list)
    log_interval: int = 5
    timeout: int = 600
    connector_limit: int = 10000
    connector_limit_per_host: int = 10000

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument(
            "--rust-lb",
            action="store_true",
            help="Use Rust load balancer",
        )
        parser.add_argument(
            "--host",
            type=str,
            default=LBArgs.host,
            help=f"Host to bind the server (default: {LBArgs.host})",
        )
        parser.add_argument(
            "--port",
            type=int,
            default=LBArgs.port,
            help=f"Port to bind the server (default: {LBArgs.port})",
        )
        parser.add_argument(
            "--policy",
            type=str,
            default=LBArgs.policy,
            choices=["random", "po2"],
            help=f"Policy to use for load balancing (default: {LBArgs.policy})",
        )
        parser.add_argument(
            "--prefill",
            type=str,
            default=[],
            nargs="+",
            help="URLs for prefill servers",
        )
        parser.add_argument(
            "--decode",
            type=str,
            default=[],
            nargs="+",
            help="URLs for decode servers",
        )
        parser.add_argument(
            "--prefill-bootstrap-ports",
            type=int,
            nargs="+",
            help="Bootstrap ports for prefill servers",
        )
        parser.add_argument(
            "--log-interval",
            type=int,
            default=LBArgs.log_interval,
            help=f"Log interval in seconds (default: {LBArgs.log_interval})",
        )
        parser.add_argument(
            "--timeout",
            type=int,
            default=LBArgs.timeout,
            help=f"Timeout in seconds (default: {LBArgs.timeout})",
        )
        parser.add_argument(
            "--connector-limit",
            type=int,
            default=LBArgs.connector_limit,
            help=f"Total connection pool size (default: {LBArgs.connector_limit})",
        )
        parser.add_argument(
            "--connector-limit-per-host",
            type=int,
            default=LBArgs.connector_limit_per_host,
            help=f"Connections per host (default: {LBArgs.connector_limit_per_host})",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs":
        bootstrap_ports = args.prefill_bootstrap_ports
        if bootstrap_ports is None:
            bootstrap_ports = [None] * len(args.prefill)
        elif len(bootstrap_ports) == 1:
            bootstrap_ports = bootstrap_ports * len(args.prefill)
        else:
            if len(bootstrap_ports) != len(args.prefill):
                raise ValueError(
                    "Number of prefill URLs must match number of bootstrap ports"
                )

        prefill_infos = [
            (url, port) for url, port in zip(args.prefill, bootstrap_ports)
        ]

        return cls(
            rust_lb=args.rust_lb,
            host=args.host,
            port=args.port,
            policy=args.policy,
            prefill_infos=prefill_infos,
            decode_infos=args.decode,
            log_interval=args.log_interval,
            timeout=args.timeout,
            connector_limit=args.connector_limit,
            connector_limit_per_host=args.connector_limit_per_host,
        )

    def __post_init__(self):
        if not self.rust_lb:
            assert (
                self.policy == "random"
            ), "Only random policy is supported for Python load balancer"


def main():
    parser = argparse.ArgumentParser(
        description="PD Disaggregation Load Balancer Server"
    )
    LBArgs.add_cli_args(parser)
    args = parser.parse_args()
    lb_args = LBArgs.from_cli_args(args)

    if lb_args.rust_lb:
        from sgl_pdlb._rust import LoadBalancer as RustLB

        RustLB(
            host=lb_args.host,
            port=lb_args.port,
            policy=lb_args.policy,
            prefill_infos=lb_args.prefill_infos,
            decode_infos=lb_args.decode_infos,
            log_interval=lb_args.log_interval,
            timeout=lb_args.timeout,
        ).start()
    else:
        # from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
        from mini_lb import PrefillConfig, run

        prefill_configs = [
            PrefillConfig(url, port) for url, port in lb_args.prefill_infos
        ]
        run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port,
            connector_limit=lb_args.connector_limit,
            connector_limit_per_host=lb_args.connector_limit_per_host)


if __name__ == "__main__":
    main()

mini_lb.py:

  """
Minimal HTTP load balancer for prefill and decode servers for testing.
"""

import asyncio
import dataclasses
import logging
import random
import resource
import urllib
from itertools import chain
from typing import List, Optional

import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse

# from sglang.srt.disaggregation.utils import PDRegistryRequest


def setup_logger():
    logger = logging.getLogger("pdlb")
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter(
        "[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    return logger


def increase_file_descriptor_limits():
    """Increase file descriptor limits to handle many connections."""
    try:
        # Get current limits
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)

        # Recommended minimum for high-concurrency applications
        recommended_min = 65536

        if soft < recommended_min:
            # Try to increase to recommended minimum or hard limit if lower
            new_limit = min(recommended_min, hard)
            try:
                resource.setrlimit(resource.RLIMIT_NOFILE, (new_limit, hard))
                logger.info(f"Increased file descriptor limit from {soft} to {new_limit}")
            except (ValueError, OSError) as e:
                logger.warning(f"Could not increase file descriptor limit: {e}")
                logger.info(f"Current limit: {soft}, Hard limit: {hard}")

                # Provide system-level commands for user to run
                logger.info("To increase system limits permanently, run:")
                logger.info("  echo '* soft nofile 65536' >> /etc/security/limits.conf")
                logger.info("  echo '* hard nofile 65536' >> /etc/security/limits.conf")
                logger.info("Or for temporary increase: ulimit -n 65536")
        else:
            logger.info(f"File descriptor limit is already sufficient: {soft}")

    except (ImportError, OSError) as e:
        logger.warning(f"Could not check/increase file descriptor limits: {e}")


logger = setup_logger()


@dataclasses.dataclass
class PrefillConfig:
    url: str
    bootstrap_port: Optional[int] = None


class MiniLoadBalancer:
    def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str],
                 connector_limit=10000, connector_limit_per_host=1000):
        self.prefill_configs = prefill_configs
        self.prefill_servers = [p.url for p in prefill_configs]
        self.decode_servers = decode_servers
        self.connector_limit = connector_limit
        self.connector_limit_per_host = connector_limit_per_host

    def add_prefill_server(self, new_prefill_config: PrefillConfig):
        self.prefill_configs.append(new_prefill_config)
        self.prefill_servers.append(new_prefill_config.url)

    def add_decode_server(self, new_decode_server: str):
        self.decode_servers.append(new_decode_server)

    def select_pair(self):
        # TODO: return some message instead of panic
        assert len(self.prefill_configs) > 0, "No prefill servers available"
        assert len(self.decode_servers) > 0, "No decode servers available"

        prefill_config = random.choice(self.prefill_configs)
        decode_server = random.choice(self.decode_servers)
        return prefill_config.url, prefill_config.bootstrap_port, decode_server

    async def generate(
        self, modified_request, prefill_server, decode_server, endpoint
    ) -> ORJSONResponse:
        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

        # Create connector with increased connection limits
        connector = aiohttp.TCPConnector(
            limit=self.connector_limit,  # Total connection pool size
            limit_per_host=self.connector_limit_per_host,  # Connections per host
            force_close=False,
            enable_cleanup_closed=True,
        )

        async with aiohttp.ClientSession(
            connector=connector,
            timeout=aiohttp.ClientTimeout(
                total=3600
            )  # Add timeout for request reliability
        ) as session:
            tasks = [
                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                session.post(f"{decode_server}/{endpoint}", json=modified_request),
            ]

            # Wait for both responses to complete. Prefill should end first.
            prefill_response, decode_response = await asyncio.gather(*tasks)

            if "return_logprob" in modified_request:

                prefill_json = await prefill_response.json()
                ret_json = await decode_response.json()

                # merge `meta_info.input_token_logprobs` from prefill to decode
                if "meta_info" in ret_json:
                    if "input_token_logprobs" in ret_json["meta_info"]:
                        ret_json["meta_info"]["input_token_logprobs"] = (
                            prefill_json["meta_info"]["input_token_logprobs"]
                            + ret_json["meta_info"]["input_token_logprobs"]
                        )
            else:
                ret_json = await decode_response.json()

            return ORJSONResponse(
                content=ret_json,
                status_code=decode_response.status,
            )

    async def generate_stream(
        self, modified_request, prefill_server, decode_server, endpoint="generate"
    ):
        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

        async def stream_results():
            # Create connector with increased connection limits
            connector = aiohttp.TCPConnector(
                limit=self.connector_limit,  # Total connection pool size
                limit_per_host=self.connector_limit_per_host,  # Connections per host
                force_close=False,
                enable_cleanup_closed=True,
            )

            async with aiohttp.ClientSession(
                connector=connector,
                timeout=aiohttp.ClientTimeout(
                    total=3600
                )  # Add timeout for request reliability
            ) as session:
                # Create the tasks for both prefill and decode requests
                tasks = [
                    session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                    session.post(f"{decode_server}/{endpoint}", json=modified_request),
                ]
                # Wait for both responses to complete. Since this is streaming, they return immediately.
                prefill_response, decode_response = await asyncio.gather(*tasks)

                if modified_request.get("return_logprob", False):
                    prefill_chunks = []
                    async for chunk in prefill_response.content:
                        prefill_chunks.append(chunk)

                    first_prefill_chunk = (
                        prefill_chunks[0].decode("utf-8")[5:].strip("\n")
                    )
                    first_prefill_chunk_json = orjson.loads(first_prefill_chunk)

                    async for chunk in decode_response.content:
                        # Note: This is inefficient
                        # merge prefill input_token_logprobs, output_token_logprobs to decode
                        decoded_chunk = chunk.decode("utf-8")
                        if (
                            decoded_chunk
                            and decoded_chunk.startswith("data:")
                            and "[DONE]" not in decoded_chunk
                        ):
                            ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
                            ret_json["meta_info"]["input_token_logprobs"] = (
                                first_prefill_chunk_json["meta_info"][
                                    "input_token_logprobs"
                                ]
                                + ret_json["meta_info"]["input_token_logprobs"]
                            )

                            yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
                        else:
                            yield chunk
                else:
                    async for chunk in decode_response.content:
                        yield chunk

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
        )


app = FastAPI()
load_balancer: Optional[MiniLoadBalancer] = None


@app.get("/health")
async def health_check():
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate():
    prefill_servers, decode_servers = (
        load_balancer.prefill_servers,
        load_balancer.decode_servers,
    )

    # Create connector with increased connection limits
    connector = aiohttp.TCPConnector(
        limit=load_balancer.connector_limit,  # Total connection pool size
        limit_per_host=load_balancer.connector_limit_per_host,  # Connections per host
        force_close=False,
        enable_cleanup_closed=True,
    )

    async with aiohttp.ClientSession(connector=connector) as session:
        # Create the tasks
        tasks = []
        for server in chain(prefill_servers, decode_servers):
            tasks.append(session.post(f"{server}/health_generate"))
        for i, response in enumerate(asyncio.as_completed(tasks)):
            await response
    return Response(status_code=200)


@app.post("/flush_cache")
async def flush_cache():
    prefill_servers, decode_servers = (
        load_balancer.prefill_servers,
        load_balancer.decode_servers,
    )

    # Create connector with increased connection limits
    connector = aiohttp.TCPConnector(
        limit=load_balancer.connector_limit,  # Total connection pool size
        limit_per_host=load_balancer.connector_limit_per_host,  # Connections per host
        force_close=False,
        enable_cleanup_closed=True,
    )

    async with aiohttp.ClientSession(connector=connector) as session:
        # Create the tasks
        tasks = []
        for server in chain(prefill_servers, decode_servers):
            tasks.append(session.post(f"{server}/flush_cache"))
        for i, response in enumerate(asyncio.as_completed(tasks)):
            await response
    return Response(status_code=200)


@app.get("/get_server_info")
async def get_server_info():
    prefill_servers, decode_servers = (
        load_balancer.prefill_servers,
        load_balancer.decode_servers,
    )
    prefill_infos = []
    decode_infos = []

    # Create connector with increased connection limits
    connector = aiohttp.TCPConnector(
        limit=load_balancer.connector_limit,  # Total connection pool size
        limit_per_host=load_balancer.connector_limit_per_host,  # Connections per host
        force_close=False,
        enable_cleanup_closed=True,
    )

    async with aiohttp.ClientSession(connector=connector) as session:
        for server in chain(prefill_servers):
            server_info = await session.get(f"{server}/get_server_info")
            prefill_infos.append(await server_info.json())
        for server in chain(decode_servers):
            server_info = await session.get(f"{server}/get_server_info")
            decode_infos.append(await server_info.json())

    return {"prefill": prefill_infos, "decode": decode_infos}


@app.get("/get_model_info")
async def get_model_info():
    # Dummy model information
    model_info = {
        "model_path": "/path/to/dummy/model",
        "tokenizer_path": "/path/to/dummy/tokenizer",
        "is_generation": True,
        "preferred_sampling_params": {"temperature": 0.7, "max_new_tokens": 128},
    }
    return ORJSONResponse(content=model_info)


@app.post("/generate")
async def handle_generate_request(request_data: dict):
    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()

    # Parse and transform prefill_server for bootstrap data
    parsed_url = urllib.parse.urlparse(prefill_server)
    hostname = parsed_url.hostname
    modified_request = request_data.copy()

    batch_size = _get_request_batch_size(modified_request)
    if batch_size is not None:
        modified_request.update(
            {
                "bootstrap_host": [hostname] * batch_size,
                "bootstrap_port": [bootstrap_port] * batch_size,
                "bootstrap_room": [
                    _generate_bootstrap_room() for _ in range(batch_size)
                ],
            }
        )
    else:
        modified_request.update(
            {
                "bootstrap_host": hostname,
                "bootstrap_port": bootstrap_port,
                "bootstrap_room": _generate_bootstrap_room(),
            }
        )

    if request_data.get("stream", False):
        return await load_balancer.generate_stream(
            modified_request, prefill_server, decode_server, "generate"
        )
    else:
        return await load_balancer.generate(
            modified_request, prefill_server, decode_server, "generate"
        )


async def _forward_to_backend(request_data: dict, endpoint_name: str):
    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()

    # Parse and transform prefill_server for bootstrap data
    parsed_url = urllib.parse.urlparse(prefill_server)
    hostname = parsed_url.hostname
    modified_request = request_data.copy()
    modified_request.update(
        {
            "bootstrap_host": hostname,
            "bootstrap_port": bootstrap_port,
            "bootstrap_room": _generate_bootstrap_room(),
        }
    )

    if request_data.get("stream", False):
        return await load_balancer.generate_stream(
            modified_request,
            prefill_server,
            decode_server,
            endpoint=endpoint_name,
        )
    else:
        return await load_balancer.generate(
            modified_request,
            prefill_server,
            decode_server,
            endpoint=endpoint_name,
        )


@app.post("/v1/chat/completions")
async def handle_chat_completion_request(request_data: dict):
    return await _forward_to_backend(request_data, "v1/chat/completions")


@app.post("/v1/completions")
async def handle_completion_request(request_data: dict):
    return await _forward_to_backend(request_data, "v1/completions")


def _generate_bootstrap_room():
    return random.randint(0, 2**63 - 1)


# We may utilize `GenerateReqInput`'s logic later
def _get_request_batch_size(request):
    if (text := request.get("text")) is not None:
        return None if isinstance(text, str) else len(text)
    if (input_ids := request.get("input_ids")) is not None:
        return None if isinstance(input_ids[0], int) else len(input_ids)
    return None


@app.get("/v1/models")
async def get_models():
    prefill_server = load_balancer.prefill_servers[0]  # Get the first prefill server

    # Create connector with increased connection limits
    connector = aiohttp.TCPConnector(
        limit=load_balancer.connector_limit,  # Total connection pool size
        limit_per_host=load_balancer.connector_limit_per_host,  # Connections per host
        force_close=False,
        enable_cleanup_closed=True,
    )

    async with aiohttp.ClientSession(connector=connector) as session:
        try:
            response = await session.get(f"{prefill_server}/v1/models")
            if response.status != 200:
                raise HTTPException(
                    status_code=response.status,
                    detail=f"Prefill server error: Status {response.status}",
                )
            return ORJSONResponse(content=await response.json())
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))


# @app.post("/register")
# async def register(obj: PDRegistryRequest):
#     if obj.mode == "prefill":
#         load_balancer.add_prefill_server(
#             PrefillConfig(obj.registry_url, obj.bootstrap_port)
#         )
#         logger.info(
#             f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
#         )
#     elif obj.mode == "decode":
#         load_balancer.add_decode_server(obj.registry_url)
#         logger.info(f"Registered decode server: {obj.registry_url}")
#     else:
#         raise HTTPException(
#             status_code=400,
#             detail="Invalid mode. Must be either PREFILL or DECODE.",
#         )
#
#     logger.info(
#         f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
#         f"#Decode servers: {len(load_balancer.decode_servers)}"
#     )
#
#     return Response(status_code=200)


def run(prefill_configs, decode_addrs, host, port, connector_limit=10000, connector_limit_per_host=1000):
    # Increase file descriptor limits for high-concurrency
    increase_file_descriptor_limits()

    global load_balancer
    load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs, connector_limit, connector_limit_per_host)
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    # FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
    # from sglang.srt.disaggregation.launch_lb import main
    from launch_lb import main

    main()

Summary

The core ideas of the PD disaggregation ratio experiment:

  1. Independent testing: benchmark Prefill and Decode separately to find the best configuration for each
  2. Ratio matching: compute the optimal PD ratio from their RPS capabilities
  3. End-to-end validation: run an end-to-end test on the chosen configuration to verify overall performance
  4. Mind the details: watch out for timeouts, file descriptor limits, the choice of test tool, and similar pitfalls

With this methodology you can systematically find the optimal PD disaggregation ratio that meets the SLO.