sglang-attention

Let's use nano-sglang as the example.

Take llama2-7b: apply_chat_template prepends `<s>_` to the input (where `_` is the SentencePiece whitespace token, id 29871).

So the input "12345" becomes `<s>_12345`, whose input_ids are [1, 29871, 29896, 29906, 29941, 29946, 29945].
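
This is easy to verify with the HF tokenizer (a minimal sketch; it calls the tokenizer directly rather than apply_chat_template, since the BOS token plus the SentencePiece prefix space already yield the same ids):

```python
from transformers import AutoTokenizer

# assumes access to the meta-llama/Llama-2-7b-hf checkpoint on the Hub
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
print(tok("12345").input_ids)
# [1, 29871, 29896, 29906, 29941, 29946, 29945]
# 1 = <s>, 29871 = "▁", then one id per digit
```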

prefill

The first time, we send "12345", so [1, 29871, 29896, 29906, 29941, 29946, 29945] gets cached.

The second time, we send "1234512345": 7 tokens are already cached, and the remaining 5 need to be extended.

req.input_ids = [1, 29871, 29896, 29906, 29941, 29946, 29945, 29896, 29906, 29941, 29946, 29945]
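
nano-sglang locates the reusable prefix with its radix-tree cache; for this single request the effect is the same as a longest-common-prefix match. A hypothetical sketch of that effect (not the actual radix-tree code):

```python
def match_prefix(cached_ids, input_ids):
    """Length of the longest common prefix -- what the radix-tree lookup returns."""
    n = 0
    for a, b in zip(cached_ids, input_ids):
        if a != b:
            break
        n += 1
    return n

cached = [1, 29871, 29896, 29906, 29941, 29946, 29945]
req_input_ids = cached + [29896, 29906, 29941, 29946, 29945]
prefix_len = match_prefix(cached, req_input_ids)  # 7 -> becomes prefix_lens
extend_len = len(req_input_ids) - prefix_len      # 5 tokens left to extend
```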

```python
@torch.inference_mode()
def forward_extend(
    self,
    input_ids,
    req_pool_indices,
    seq_lens,
    prefix_lens,
    out_cache_loc,
    return_normalized_logprob,
):
    input_metadata = InputMetadata.create(
        self,
        forward_mode=ForwardMode.EXTEND,
        tp_size=self.tp_size,
        req_pool_indices=req_pool_indices, # [0]: slot of each request in req_pool; size = number of requests
        seq_lens=seq_lens, # [12]: total length of each request; size = number of requests
        prefix_lens=prefix_lens, # [7]: number of already-cached tokens per request; size = number of requests
        out_cache_loc=out_cache_loc, # [16, 17, 18, 19, 20]: KV-cache slots for the newly added tokens; size = number of new tokens in the batch
        return_normalized_logprob=return_normalized_logprob, # usually False; whether to return normalized token logprobs
    )
    return self.model.forward(input_ids, input_metadata.positions, input_metadata)
```
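
For the running "1234512345" request, the argument values work out as annotated above. Here is a hypothetical sketch of those values (not an actual nano-sglang call; the assumption that extend mode only feeds the uncached tokens is consistent with the positions [7..11] computed below):

```python
import torch

# Hypothetical values for the running example: one request, 7 cached + 5 new tokens
req_pool_indices = torch.tensor([0])    # the request occupies slot 0 of req_pool
seq_lens         = torch.tensor([12])   # full length, cached prefix included
prefix_lens      = torch.tensor([7])    # "<s>_12345" is already in the KV cache
out_cache_loc    = torch.arange(16, 21) # 5 freshly allocated KV-cache slots: 16..20
# assumption: in extend mode, only the 5 uncached tokens go through the model
input_ids        = torch.tensor([29896, 29906, 29941, 29946, 29945])
```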

```python
@dataclass
class InputMetadata:
    model_runner: "ModelRunner"
    forward_mode: ForwardMode
    batch_size: int
    total_num_tokens: int
    max_seq_len: int
    req_pool_indices: torch.Tensor
    start_loc: torch.Tensor
    seq_lens: torch.Tensor
    prefix_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # for extend
    extend_seq_lens: torch.Tensor = None
    extend_start_loc: torch.Tensor = None
    max_extend_len: int = 0

    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    other_kv_index: torch.Tensor = None
    return_normalized_logprob: bool = False

    def init_extend_args(self):
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], 0)
        self.max_extend_len = int(torch.max(self.extend_seq_lens))
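        # Hypothetical 2-request batch to illustrate the math above:
        #   seq_lens = [12, 9], prefix_lens = [7, 4]
        #   -> extend_seq_lens  = [5, 5]
        #   -> extend_start_loc = [0, 5]  (offsets into the flattened extend tokens)
        #   -> max_extend_len   = 5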

    @classmethod
    def create(
        cls,
        model_runner,
        tp_size,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        out_cache_loc,
        out_cache_cont_start=None,
        out_cache_cont_end=None,
        return_normalized_logprob=False,
    ):
        batch_size = len(req_pool_indices) # 1, i.e. the batch size
        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0) # exclusive prefix sum of seq_lens
        total_num_tokens = int(torch.sum(seq_lens)) # 12: total number of tokens
        max_seq_len = int(torch.max(seq_lens)) # 12: longest sequence in the batch
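        # Hypothetical 2-request batch: seq_lens = [12, 9]
        # -> start_loc = [0, 12], total_num_tokens = 21, max_seq_len = 12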

        if forward_mode == ForwardMode.DECODE:
            positions = (seq_lens - 1).to(torch.int64)
            other_kv_index = model_runner.req_to_token_pool.req_to_token[
                req_pool_indices[0], seq_lens[0] - 1
            ].item()
        else:
            seq_lens_np = seq_lens.cpu().numpy()
            prefix_lens_np = prefix_lens.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_np[i],
                            seq_lens_np[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            ) # [7, 8, 9, 10, 11]: position ids used later for positional encoding in the forward pass
            other_kv_index = None

        ret = cls(
            model_runner=model_runner,
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            max_seq_len=max_seq_len,
            req_pool_indices=req_pool_indices,
            start_loc=start_loc,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            return_normalized_logprob=return_normalized_logprob,
            other_kv_index=other_kv_index,
        )

        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

        ret.use_flashinfer = "flashinfer" in model_runner.model_mode
        if ret.use_flashinfer:
            ret.init_flashinfer_args(tp_size)

        return ret
```
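
As a sanity check, the EXTEND-mode position computation can be replayed on the CPU with plain numpy, using the values from the running example (a standalone recheck, not library code):

```python
import numpy as np

seq_lens_np = np.array([12])    # total request length
prefix_lens_np = np.array([7])  # cached prefix length
positions = np.concatenate(
    [np.arange(prefix_lens_np[i], seq_lens_np[i]) for i in range(len(seq_lens_np))],
    axis=0,
)
print(positions)  # [ 7  8  9 10 11] -- position ids for the 5 extended tokens
```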