投机采样

投机采样
gogongxt论文:
- Medusa https://arxiv.org/pdf/2401.10774
- EAGLE1 https://arxiv.org/pdf/2401.15077
- EAGLE2 https://arxiv.org/pdf/2406.16858
- EAGLE3 https://arxiv.org/pdf/2503.01840
Medusa
对Medusa架构来说,是在原有的LM Head单头基础上,增加几个头,预测后面的token
1. 什么是 Medusa 的“多头”?
在传统的 LLM 中,网络的最顶端只有一个 LM
Head(语言模型头)。它的任务很单一:接收最后一层的隐含层特征
Medusa 的“多头” (Multiple Heads)
,就是在这个主干网络的最后一层
- Original LM Head:负责预测第
个 Token。 - Medusa Head 1:直接预测第
个 Token。 - Medusa Head 2:直接预测第
个 Token。 - 以此类推。
2. Medusa 头的架构是什么样的?
Medusa 的头设计得非常轻量,目的是不增加太多的额外计算开销。它的架构通常是一个简单的残差块(ResBlock)。
假设 Target 模型的最后一层隐含层输出是
特征变换(ResBlock):将
通过一个带激活函数的单层线性映射,并加上残差连接: (其中
是这个头专属的权重矩阵,尺寸通常也是 )复用 LM Head 权重:拿到变换后的特征
后,Medusa 不会自己去乘以一个庞大的词表矩阵,而是直接复用 Target 模型自带的原始 LM Head 的权重矩阵 来输出概率分布:
本质上,每个 Medusa Head 就是一个单层的
MLP,它们都在努力学习一种“跳跃式”的映射关系:仅仅凭借
EAGLE1
you asked
讲一讲EAGLE1算法,是把target模型的最后一层的lm head之前的隐含层拿出来,和token的embed的输出,做了contat,然后再给到draft模型吗 我不是很了解EAGLE算法,所以需要你给我指正,讲明白这个操作流程
gemini response
你的直觉非常敏锐!你描述的核心思想基本是完全正确的。
EAGLE(Extrapolation Algorithm for Greater Language-model Efficiency)算法的核心创新点,正是复用 Target 模型(大模型)的高维特征,将其与 Token Embedding 结合,交给一个轻量级的 Draft 模型进行“特征外推”(Feature Extrapolation)。
为了让你把这个操作流程完全理顺,我们需要在你的理解上做一个时间步(Time Step)上的精确对齐。
以下是 EAGLE1 算法的详细操作流程剖析:
1. 核心变量定义
为了说得更明白,我们先定义几个简单的符号:
- Target 模型: 庞大且缓慢的主模型(比如 LLaMA-3-70B)。
: Target 模型的 LM Head(负责把隐含层特征变成词表的概率分布)。 : Target 模型的 Embedding 层(负责把 Token ID 变成词向量)。 : Target 模型在第 步时,LM Head 之前的最后一层隐含层输出(Hidden State)。 : 第 步生成的 Token ID。- Draft 模型(Auto-regressive Head): EAGLE 引入的轻量级网络(通常只有一个 Transformer 层外加一些线性映射层)。
2. 标准的 Target 模型前向传播(作为起点)
假设我们现在处于第
- Target 模型输出了当前步的隐含层特征
。 - 经过 LM Head 预测出下一个 Token:
。
到这里,我们手头有了两样关键资产:上一层的深层语义特征
3. EAGLE 的核心:Draft 模型的拼接与外推
现在,我们要开始“猜”(Draft)第
- 获取 Embedding: 拿到刚刚预测出的新 Token
的 Embedding 输出,即 。 - 特征拼接(Concat / 融合): 就像你说的,把 Target
模型上一步的隐含层
,和当前新 Token 的 Embedding 拿过来。EAGLE 会将这两者进行拼接(并通常会通过一个线性层降维整合)。 - Draft 模型预测“未来的隐含层”:
把拼接后的结果喂给轻量级的 Draft 模型。
_注意:Draft 模型的输出不是概率分布,而是预测出的“下一步的隐含层特征” 。_ - 复用 LM Head 得到 Draft Token: 拿着预测出的
,直接扔进 Target 模型自带的、已经冻结的 LM Head 里,得到下一个 Draft Token:
4. 自回归生成更长的 Draft
如果我们需要一次性猜 3 个 Token,EAGLE 会继续它的自回归循环:
- 拿刚刚生成的
和 拼接。 - 喂给 Draft 模型,得到
。 - 经过 LM Head 得到
。
循环往复,直到生成足够长度的 Draft 序列。最后,把这个序列交给 Target 模型去做 Tree Attention 进行并行验证。
总结:为什么 EAGLE 要这么设计?
传统的推测解码(比如用一个独立的 68M 小模型去猜 7B
大模型)存在一个巨大的浪费:Target
模型辛辛苦苦算出来的深层上下文特征
EAGLE 的直觉是:Target 模型的
因为只是做特征层面的“一步推演”(Extrapolation),这个任务比“从头理解语言”要简单得多,所以 EAGLE 的 Draft 模型可以做得极小(参数量常常只有 Target 模型的百分之几),同时依然保持极高的接收率(Acceptance Rate)。
you asked
那么在draft的层中,也有kvcache的概念吗,draftmodel的输入是target的隐藏层和新token的embed向量,对于一个请求来说就是(1,n) + (1,n),然后concat后就是(1,2n),然后经过fc又变成了(1,n)是这样吗,然后这里的计算要去做完整的attention的时候,会和历史的token有kvcache的过程吗
gemini response
你的理解非常精准!你对维度变换的推演完全正确,而且你极其敏锐地抓住了大模型自回归生成的核心命脉——KV Cache。
针对你的问题,我分两个部分给你明确的解答和梳理:
1. 维度的变换过程:完全正确
你描述的输入拼接和降维过程,与 EAGLE
的官方实现是完全一致的。我们用数学维度的视角再过一遍(假设
Target 模型的隐藏层维度是
- Target 隐藏层
:形状是 。 - 新 Token 的 Embedding
:形状也是 。 - Concat (拼接):将两者在特征维度拼接,变成
。 - FC (全连接层 / 线性映射):经过一个线性层,将维度从
重新压缩回 。所以输出形状又变回了 。
为什么一定要变回
2. Draft 模型中的 KV Cache 机制:是的,不仅有,而且很关键
你的直觉很准:Draft 模型内部绝对有 KV Cache,并且在做完整的 Attention 时,也会和历史 Token 发生交互。
EAGLE 的 Draft 模型虽然很小,但它并不是一个简单的多层感知机(MLP),它通常包含一层 Transformer Decoder 结构。只要有 Transformer 的 Self-Attention 机制,在自回归生成时就离不开 KV Cache。
具体流程是这样的:
A. 历史 Prompt 阶段的 KV Cache (Prefill)
当用户输入一段 Prompt 时,Target 模型在做 Prefill(预填充),Draft 模型也没闲着。 Draft 模型会把前面所有历史 Token 的特征(经过 Concat 和 FC 后的特征向量)送入自己的那一层 Transformer 中,计算出对应的 Key 和 Value,并存入 Draft 模型专属的 KV Cache 中。 注:Target 模型有一个庞大的 KV Cache(比如 32 层),而 Draft 模型有一个极小的 KV Cache(只有 1 层)。两者是物理隔离的。
B. Draft 推理阶段的 KV Cache (Decode)
假设 Target 模型刚输出了第
- 猜第 1 个 Draft Token: Draft 模型拿到
的融合特征,送入自己的 Transformer 层。此时,它会用当前的 Query 去 Attend 整个历史 Prompt 的 Draft KV Cache。算完后,把当前步的 K 和 V 追加进 Draft KV Cache,并输出预测。 - 猜第 2 个 Draft Token: Draft 模型拿到上一步自己的输出和新猜测 Token 的 Embedding,再次融合。送入 Transformer 层时,它不仅会 Attend 历史 Prompt,还会 Attend 刚刚猜出的第 1 个 Draft Token 的 KV Cache。
- 猜第 3 个 Draft Token: 同理,Attend 历史 + 前 2 个 Draft Token。
为什么 Draft 模型还需要 Attention 和历史 KV Cache?
你可能会问:Target 模型的隐藏层
原因在于:虽然
总结
- 维度推演: 你的
逻辑完全正确。 - KV Cache: Draft 模型本质上是一个“单层的小型自回归语言模型”。只要是自回归,就有自己独立的 KV Cache,它必须利用历史 Token 的 KV Cache 来保证推断的准确性。







