cs336-02-train

cs336-02-train
gogongxtTransformer 语言模型训练细节详解
本文档详细介绍 CS336 Assignment 1 中训练流程的核心细节,包括数据批处理、模型前向传播、损失函数、梯度裁剪和困惑度计算。
1. 训练流程总览
训练循环位于 train.py:545-593,核心逻辑如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
for step in range(start_step, config.total_steps):
# 1. 获取当前学习率 (余弦退火 + warmup)
lr = run_get_lr_cosine_schedule(step, learning_rate, min_lr, warmup_steps, total_steps)
# 2. 更新优化器学习率
for param_group in optimizer.param_groups:
param_group["lr"] = lr
# 3. 获取一个 batch 的数据
x, y = run_get_batch(train_data, batch_size, context_length, device)
# 4. 前向传播
logits = model(x)
# 5. 计算损失
loss = run_cross_entropy(logits.view(-1, vocab_size), y.view(-1))
# 6. 反向传播
optimizer.zero_grad()
loss.backward()
# 7. 梯度裁剪
run_gradient_clipping(model.parameters(), grad_clip)
# 8. 参数更新
optimizer.step()2. 批数据获取 (Batch Data Sampling)
2.1 数据来源
数据以 memory-mapped numpy array 形式加载,可以高效处理大文件:
1
2
3
4
5
def load_dataset(data_path: str) -> np.ndarray:
if data_path.endswith(".npy"):
return np.load(data_path, mmap_mode="r")
elif data_path.endswith(".bin"):
return np.memmap(data_path, dtype=dtype, mode="r")2.2 Batch 获取逻辑
核心函数 run_get_batch 实现如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def run_get_batch(dataset, batch_size, context_length, device):
# 最大起始索引:需要 context_length + 1 个 token (输入 + 目标)
max_start_idx = len(dataset) - context_length - 1
# 随机生成 batch_size 个起始位置
start_indices = np.random.randint(0, max_start_idx + 1, size=batch_size)
# 构建输入 x 和标签 y
# x[i] = dataset[start_idx : start_idx + context_length]
# y[i] = dataset[start_idx + 1 : start_idx + context_length + 1]
x = np.empty((batch_size, context_length), dtype=dataset.dtype)
y = np.empty((batch_size, context_length), dtype=dataset.dtype)
return torch.from_numpy(x), torch.from_numpy(y)2.3 数据维度
假设配置:batch_size=64,
context_length=256
1
2
x shape: [64, 256] # 输入 token IDs
y shape: [64, 256] # 目标 token IDs (右移一位)2.4 数据采样特点
重要特性:当前实现 不处理特殊终止符 (EOS) :
- 起始位置是全局随机选择的,可能落在任何位置
- 可能跨越文档边界
- EOS token 后面的内容也会被当作正常训练数据
这种设计的优缺点:
- ✅ 数据利用率 100%
- ✅ 实现简单
- ❌ 模型可能学习在 EOS 后继续生成
3. 模型前向传播
3.1 输入输出维度
1
logits = model(x) # x: [batch, seq_len]| 变量 | 形状 | 说明 |
|---|---|---|
x (输入) |
[batch_size, context_length] |
输入 token IDs |
logits (输出) |
[batch_size, context_length, vocab_size] |
每个位置的词汇表分数 |
3.2 为什么输出不是
[batch, 1, vocab]?
关键设计:训练时使用整个序列进行预测,而不是只预测最后一个 token。
1
2
3
4
5
6
7
8
9
输入 x: [t0, t1, t2, t3, t4, t5, t6, t7]
↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
标签 y: [t1, t2, t3, t4, t5, t6, t7, t8]
logits: [batch, 8, vocab]
位置0 预测 token1
位置1 预测 token2
...
位置7 预测 token8优势:
- 数据利用率高:每个 token 都提供监督信号
- 训练更稳定:梯度来自多个位置的 loss 平均
- 建模能力强:模型学会在任何位置预测
3.3 推理时的使用
推理时只取最后一个位置的输出,其实也就是prefill,在本文不涉及推理,这里只是简单扩展一下:
1
2
3
logits = model(x) # [1, 256, vocab]
next_token_logits = logits[0, -1, :] # 只取最后一个位置
next_token = sample(next_token_logits)4. 交叉熵损失函数
4.1 实现细节
run_cross_entropy 手动实现了数值稳定的交叉熵计算:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def run_cross_entropy(inputs, targets):
"""
inputs: [batch_size, vocab_size] - 未归一化的 logit
targets: [batch_size] - 目标类别索引
"""
# 数值稳定技巧:减去最大值避免 exp 溢出
max_logits = inputs.max(dim=-1, keepdim=True).values
shifted_logits = inputs - max_logits
# log(sum(exp(logits))) - 分母
log_sum_exp = shifted_logits.exp().sum(dim=-1).log()
# 取出目标类别的 logit - 分子
target_logits = shifted_logits.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
# 交叉熵公式: -log(softmax) = -target + log(sum_exp)
loss_per_example = -target_logits + log_sum_exp
return loss_per_example.mean()4.2 数学原理
交叉熵损失 (Cross-Entropy Loss):
展开为:
数值稳定性处理:
- 直接计算
可能溢出 - 减去最大值
后再计算:
4.3 训练中的使用
1
2
3
4
5
6
# logits: [batch, seq, vocab] → 展平为 [batch*seq, vocab]
# y: [batch, seq] → 展平为 [batch*seq]
logits_flat = logits.view(-1, vocab_size) # [16384, 10000]
y_flat = y.view(-1) # [16384]
loss = run_cross_entropy(logits_flat, y_flat)5. 梯度裁剪 (Gradient Clipping)
5.1 什么是梯度裁剪?
梯度裁剪是一种防止梯度爆炸 (gradient explosion) 的技术。当梯度范数超过阈值时,按比例缩放梯度。
5.2 实现代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def run_gradient_clipping(parameters, max_l2_norm):
# 1. 计算所有参数梯度的总 L2 范数
total_norm = 0.0
for param in parameters:
if param.grad is not None:
total_norm += param.grad.data.pow(2).sum()
total_norm = total_norm.sqrt().item()
# 2. 如果超过阈值,进行裁剪
if total_norm > max_l2_norm:
clip_coef = max_l2_norm / (total_norm + 1e-6)
for param in parameters:
if param.grad is not None:
param.grad.data.mul_(clip_coef)5.3 数学原理
5.4 目的
| 问题 | 解决方案 |
|---|---|
| 梯度爆炸导致训练不稳定 | 限制最大梯度范数 |
| 梯度消失导致训练停滞 | 使用残差连接、归一化层等 |
| 数值溢出 | 梯度裁剪 + 数值稳定技巧 |
6. 学习率调度 (Cosine Annealing with Warmup)
6.1 什么是学习率调度?
学习率调度 (Learning Rate Scheduling) 是在训练过程中动态调整学习率的技术。本实现使用 余弦退火 + 线性 Warmup 策略。
6.2 实现代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def run_get_lr_cosine_schedule(
it, # 当前迭代步
max_learning_rate, # 最大学习率 (α_max)
min_learning_rate, # 最小学习率 (α_min)
warmup_iters, # Warmup 步数 (T_w)
cosine_cycle_iters, # 余弦周期总步数 (T_c)
):
# Phase 1: 线性 Warmup (0 ~ T_w)
if it < warmup_iters:
return max_learning_rate * (it / warmup_iters)
# Phase 2: 余弦退火 (T_w ~ T_c)
if it < cosine_cycle_iters:
progress = (it - warmup_iters) / (cosine_cycle_iters - warmup_iters)
return min_learning_rate + 0.5 * (max_learning_rate - min_learning_rate) * (
1 + math.cos(math.pi * progress)
)
# Phase 3: 保持最小学习率 (T_c ~ ∞)
return min_learning_rate6.3 学习率变化曲线
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
LR
↑
α_max ────────────────────────
╱ ╲
╱ ╲
╱ ╲
╱ ╲
╱ ╲
╱ ╲
╱ ╲
╱ ╲
α_min ──────────────────────────────→ Steps
↑ ↑ ↑
0 T_w T_c ∞
Phase 1: 线性 Warmup (从 0 增长到 α_max)
Phase 2: 余弦退火 (从 α_max 衰减到 α_min)
Phase 3: 保持 α_min6.4 数学公式
Phase 1 - 线性 Warmup:
Phase 2 - 余弦退火:
其中
6.5 为什么需要 Warmup?
| 阶段 | 目的 |
|---|---|
| Warmup | 初始参数随机,大的学习率会导致不稳定;Warmup 让模型逐步适应 |
| Cosine 衰减 | 后期小学习率精细调优,避免在最优解附近震荡 |
| 保持最小值 | 保持微小的探索能力,防止完全收敛到局部最优 |
6.6 训练配置示例
1
2
3
4
5
6
config = TrainingConfig(
learning_rate=3e-4, # α_max = 0.0003
min_learning_rate=3e-5, # α_min = 0.00003
warmup_steps=500, # T_w = 500
total_steps=5000, # T_c = 5000
)7. 困惑度 (Perplexity)
7.1 定义
困惑度 (Perplexity, PPL) 是语言模型性能的标准度量,表示模型对下一个 token 的平均不确定性。
7.2 计算
1
2
3
4
5
6
def compute_perplexity(loss: float) -> float:
return math.exp(loss)
# 在训练日志中
avg_loss = sum(train_losses[-log_interval:]) / log_interval
ppl = compute_perplexity(avg_loss)7.3 解释
| PPL 值 | 含义 |
|---|---|
| 1.0 | 完美预测,模型完全确定 |
| 10 | 模型在 10 个 token 中选 1 个 |
| 100 | 模型在 100 个 token 中选 1 个 |
| → ∞ | 模型完全随机猜测 |
8. 总结
本文档介绍了 Transformer 语言模型训练的核心组件:
- 批数据获取:随机采样,使用整个序列而非单点预测
- 模型输出:每个位置预测下一个 token,维度
[batch, seq, vocab] - 损失函数:数值稳定的交叉熵实现
- 梯度裁剪:防止梯度爆炸,保持训练稳定
- 困惑度:衡量模型预测能力的指数度量
这些组件共同构成了现代语言模型训练的基础。
评论
匿名评论隐私政策




