sglang-attention

gogongxt: the walkthrough below uses nano-sglang as the example.
Take llama2-7b as the model. apply_chat_template prepends <s>_ to the prompt, so the input "12345" becomes <s>_12345, whose input_ids are [1, 29871, 29896, 29906, 29941, 29946, 29945].
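As a quick sanity check, these ids can be reproduced with the Hugging Face tokenizer. This is a minimal sketch; the repo name is illustrative and the exact chat-template behavior depends on the checkpoint you load:

```python
from transformers import AutoTokenizer

# Illustrative repo name; any llama2-7b tokenizer should behave the same way.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# The llama2 tokenizer prepends the BOS token <s> (id 1) and a leading-space
# token (id 29871), then encodes each digit separately, which is where
# [1, 29871, 29896, 29906, 29941, 29946, 29945] comes from.
print(tokenizer("12345").input_ids)
```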
prefill
The first time, we send "12345", so [1, 29871, 29896, 29906, 29941, 29946, 29945] gets cached.
The second time, we send "1234512345": 7 tokens are already cached, and 5 tokens need to be extended, as the sketch below illustrates.
req.input_ids = [1, 29871, 29896, 29906, 29941, 29946, 29945, 29896, 29906, 29941, 29946, 29945]
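Conceptually, the scheduler matches the new request's input_ids against the already-cached tokens (sglang does this with a radix tree) and only the unmatched suffix goes through the extend (prefill) pass. A rough sketch of that split, using a hypothetical helper rather than sglang's actual cache API:

```python
def split_cached_prefix(new_ids, cached_ids):
    """Return (prefix_len, extend_ids): how many leading tokens are already cached,
    and which tokens still need to be prefilled in the extend pass."""
    prefix_len = 0
    for a, b in zip(new_ids, cached_ids):
        if a != b:
            break
        prefix_len += 1
    return prefix_len, new_ids[prefix_len:]

cached = [1, 29871, 29896, 29906, 29941, 29946, 29945]          # "12345" from the first request
new = [1, 29871, 29896, 29906, 29941, 29946, 29945,
       29896, 29906, 29941, 29946, 29945]                       # "1234512345"
prefix_len, extend_ids = split_cached_prefix(new, cached)
print(prefix_len, extend_ids)  # 7 [29896, 29906, 29941, 29946, 29945]
```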
```python
@torch.inference_mode()
def forward_extend(
    self,
    input_ids,
    req_pool_indices,
    seq_lens,
    prefix_lens,
    out_cache_loc,
    return_normalized_logprob,
):
    input_metadata = InputMetadata.create(
        self,
        forward_mode=ForwardMode.EXTEND,
        tp_size=self.tp_size,
        req_pool_indices=req_pool_indices,  # [0]: each request's slot in req_pool; size = number of requests
        seq_lens=seq_lens,  # [12]: total length of each request; size = number of requests
        prefix_lens=prefix_lens,  # [7]: number of already-cached tokens per request; size = number of requests
        out_cache_loc=out_cache_loc,  # [16, 17, 18, 19, 20]: KV-cache slots for the newly added tokens; size = number of new tokens in the batch
        return_normalized_logprob=return_normalized_logprob,  # usually False: whether to return normalized token logprobs
    )
    return self.model.forward(input_ids, input_metadata.positions, input_metadata)
```
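For the "1234512345" request above, the arguments would carry values roughly like the following. This is a sketch based on the comments above: the pool indices 0 and 16..20 come from this particular run's allocator, and input_ids here holds only the 5 extend tokens, matching the lengths of positions and out_cache_loc:

```python
import torch

# Illustrative values for the single-request extend walkthrough, not an sglang API.
req_pool_indices = torch.tensor([0], device="cuda")                    # request's slot in req_to_token_pool
seq_lens = torch.tensor([12], device="cuda")                           # 7 cached + 5 new tokens
prefix_lens = torch.tensor([7], device="cuda")                         # tokens already in the KV cache
out_cache_loc = torch.tensor([16, 17, 18, 19, 20], device="cuda")      # KV slots for the 5 new tokens
input_ids = torch.tensor([29896, 29906, 29941, 29946, 29945], device="cuda")  # the extend part of "1234512345"
```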
InputMetadata.create then packs these arguments, together with the KV pools and the computed positions, into the metadata object that the attention kernels consume:

```python
@dataclass
class InputMetadata:
    model_runner: "ModelRunner"
    forward_mode: ForwardMode
    batch_size: int
    total_num_tokens: int
    max_seq_len: int
    req_pool_indices: torch.Tensor
    start_loc: torch.Tensor
    seq_lens: torch.Tensor
    prefix_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # for extend
    extend_seq_lens: torch.Tensor = None
    extend_start_loc: torch.Tensor = None
    max_extend_len: int = 0

    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    other_kv_index: torch.Tensor = None
    return_normalized_logprob: bool = False

    def init_extend_args(self):
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], 0)
        self.max_extend_len = int(torch.max(self.extend_seq_lens))

    @classmethod
    def create(
        cls,
        model_runner,
        tp_size,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        out_cache_loc,
        out_cache_cont_start=None,
        out_cache_cont_end=None,
        return_normalized_logprob=False,
    ):
        batch_size = len(req_pool_indices)  # 1, i.e. the batch size
        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)  # prefix sum of seq_lens (token counts per request)
        total_num_tokens = int(torch.sum(seq_lens))  # 12, total number of tokens
        max_seq_len = int(torch.max(seq_lens))  # 12, maximum sequence length

        if forward_mode == ForwardMode.DECODE:
            positions = (seq_lens - 1).to(torch.int64)
            other_kv_index = model_runner.req_to_token_pool.req_to_token[
                req_pool_indices[0], seq_lens[0] - 1
            ].item()
        else:
            seq_lens_np = seq_lens.cpu().numpy()
            prefix_lens_np = prefix_lens.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_np[i],
                            seq_lens_np[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )  # [7, 8, 9, 10, 11]; positions are used later for positional encoding during inference
            other_kv_index = None

        ret = cls(
            model_runner=model_runner,
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            max_seq_len=max_seq_len,
            req_pool_indices=req_pool_indices,
            start_loc=start_loc,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            return_normalized_logprob=return_normalized_logprob,
            other_kv_index=other_kv_index,
        )

        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

        ret.use_flashinfer = "flashinfer" in model_runner.model_mode
        if ret.use_flashinfer:
            ret.init_flashinfer_args(tp_size)

        return ret
```
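Plugging the example batch into the code above, the extend bookkeeping and positions come out as follows. This is a quick CPU check with plain torch and numpy, mirroring init_extend_args and the EXTEND branch of create:

```python
import numpy as np
import torch

seq_lens = torch.tensor([12])
prefix_lens = torch.tensor([7])

# init_extend_args
extend_seq_lens = seq_lens - prefix_lens                       # tensor([5]): tokens to run through this extend pass
extend_start_loc = torch.zeros_like(seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], 0)   # tensor([0]): per-request offset into the flattened extend tokens
max_extend_len = int(torch.max(extend_seq_lens))               # 5

# EXTEND branch of create(): positions of the new tokens inside their sequences
positions = np.concatenate(
    [np.arange(p, s) for p, s in zip(prefix_lens.numpy(), seq_lens.numpy())]
)
print(extend_seq_lens, extend_start_loc, max_extend_len, positions)
# tensor([5]) tensor([0]) 5 [ 7  8  9 10 11]
```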