cs336-02-train

Transformer 语言模型训练细节详解

本文档详细介绍 CS336 Assignment 1 中训练流程的核心细节,包括数据批处理、模型前向传播、损失函数、梯度裁剪和困惑度计算。


1. 训练流程总览

训练循环位于 train.py:545-593,核心逻辑如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
for step in range(start_step, config.total_steps):
    # 1. 获取当前学习率 (余弦退火 + warmup)
    lr = run_get_lr_cosine_schedule(step, learning_rate, min_lr, warmup_steps, total_steps)

    # 2. 更新优化器学习率
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    # 3. 获取一个 batch 的数据
    x, y = run_get_batch(train_data, batch_size, context_length, device)

    # 4. 前向传播
    logits = model(x)

    # 5. 计算损失
    loss = run_cross_entropy(logits.view(-1, vocab_size), y.view(-1))

    # 6. 反向传播
    optimizer.zero_grad()
    loss.backward()

    # 7. 梯度裁剪
    run_gradient_clipping(model.parameters(), grad_clip)

    # 8. 参数更新
    optimizer.step()

2. 批数据获取 (Batch Data Sampling)

2.1 数据来源

数据以 memory-mapped numpy array 形式加载,可以高效处理大文件:

1
2
3
4
5
def load_dataset(data_path: str) -> np.ndarray:
    if data_path.endswith(".npy"):
        return np.load(data_path, mmap_mode="r")
    elif data_path.endswith(".bin"):
        return np.memmap(data_path, dtype=dtype, mode="r")

2.2 Batch 获取逻辑

核心函数 run_get_batch 实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def run_get_batch(dataset, batch_size, context_length, device):
    # 最大起始索引:需要 context_length + 1 个 token (输入 + 目标)
    max_start_idx = len(dataset) - context_length - 1

    # 随机生成 batch_size 个起始位置
    start_indices = np.random.randint(0, max_start_idx + 1, size=batch_size)

    # 构建输入 x 和标签 y
    # x[i] = dataset[start_idx : start_idx + context_length]
    # y[i] = dataset[start_idx + 1 : start_idx + context_length + 1]
    x = np.empty((batch_size, context_length), dtype=dataset.dtype)
    y = np.empty((batch_size, context_length), dtype=dataset.dtype)

    return torch.from_numpy(x), torch.from_numpy(y)

2.3 数据维度

假设配置:batch_size=64, context_length=256

1
2
x shape: [64, 256]   # 输入 token IDs
y shape: [64, 256]   # 目标 token IDs (右移一位)

2.4 数据采样特点

重要特性:当前实现 不处理特殊终止符 (EOS)

  • 起始位置是全局随机选择的,可能落在任何位置
  • 可能跨越文档边界
  • EOS token 后面的内容也会被当作正常训练数据

这种设计的优缺点:

  • ✅ 数据利用率 100%
  • ✅ 实现简单
  • ❌ 模型可能学习在 EOS 后继续生成

3. 模型前向传播

3.1 输入输出维度

1
logits = model(x)  # x: [batch, seq_len]
变量 形状 说明
x (输入) [batch_size, context_length] 输入 token IDs
logits (输出) [batch_size, context_length, vocab_size] 每个位置的词汇表分数

3.2 为什么输出不是 [batch, 1, vocab]

关键设计:训练时使用整个序列进行预测,而不是只预测最后一个 token。

1
2
3
4
5
6
7
8
9
输入 x:   [t0, t1, t2, t3, t4, t5, t6, t7]
            ↓   ↓   ↓   ↓   ↓   ↓   ↓   ↓
标签 y:   [t1, t2, t3, t4, t5, t6, t7, t8]

 logits:  [batch, 8, vocab]
          位置0 预测 token1
          位置1 预测 token2
          ...
          位置7 预测 token8

优势

  1. 数据利用率高:每个 token 都提供监督信号
  2. 训练更稳定:梯度来自多个位置的 loss 平均
  3. 建模能力强:模型学会在任何位置预测

3.3 推理时的使用

推理时只取最后一个位置的输出,其实也就是prefill,在本文不涉及推理,这里只是简单扩展一下:

1
2
3
logits = model(x)                    # [1, 256, vocab]
next_token_logits = logits[0, -1, :] # 只取最后一个位置
next_token = sample(next_token_logits)

4. 交叉熵损失函数

4.1 实现细节

run_cross_entropy 手动实现了数值稳定的交叉熵计算:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def run_cross_entropy(inputs, targets):
    """
    inputs:  [batch_size, vocab_size] - 未归一化的 logit
    targets: [batch_size] - 目标类别索引
    """

    # 数值稳定技巧:减去最大值避免 exp 溢出
    max_logits = inputs.max(dim=-1, keepdim=True).values
    shifted_logits = inputs - max_logits

    # log(sum(exp(logits))) - 分母
    log_sum_exp = shifted_logits.exp().sum(dim=-1).log()

    # 取出目标类别的 logit - 分子
    target_logits = shifted_logits.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)

    # 交叉熵公式: -log(softmax) = -target + log(sum_exp)
    loss_per_example = -target_logits + log_sum_exp

    return loss_per_example.mean()

4.2 数学原理

交叉熵损失 (Cross-Entropy Loss):

展开为:

数值稳定性处理:

  • 直接计算 可能溢出
  • 减去最大值 后再计算:

4.3 训练中的使用

1
2
3
4
5
6
# logits: [batch, seq, vocab] → 展平为 [batch*seq, vocab]
# y:      [batch, seq]       → 展平为 [batch*seq]
logits_flat = logits.view(-1, vocab_size)  # [16384, 10000]
y_flat = y.view(-1)                         # [16384]

loss = run_cross_entropy(logits_flat, y_flat)

5. 梯度裁剪 (Gradient Clipping)

5.1 什么是梯度裁剪?

梯度裁剪是一种防止梯度爆炸 (gradient explosion) 的技术。当梯度范数超过阈值时,按比例缩放梯度。

5.2 实现代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def run_gradient_clipping(parameters, max_l2_norm):
    # 1. 计算所有参数梯度的总 L2 范数
    total_norm = 0.0
    for param in parameters:
        if param.grad is not None:
            total_norm += param.grad.data.pow(2).sum()
    total_norm = total_norm.sqrt().item()

    # 2. 如果超过阈值,进行裁剪
    if total_norm > max_l2_norm:
        clip_coef = max_l2_norm / (total_norm + 1e-6)
        for param in parameters:
            if param.grad is not None:
                param.grad.data.mul_(clip_coef)

5.3 数学原理

5.4 目的

问题 解决方案
梯度爆炸导致训练不稳定 限制最大梯度范数
梯度消失导致训练停滞 使用残差连接、归一化层等
数值溢出 梯度裁剪 + 数值稳定技巧

6. 学习率调度 (Cosine Annealing with Warmup)

6.1 什么是学习率调度?

学习率调度 (Learning Rate Scheduling) 是在训练过程中动态调整学习率的技术。本实现使用 余弦退火 + 线性 Warmup 策略。

6.2 实现代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def run_get_lr_cosine_schedule(
    it,                    # 当前迭代步
    max_learning_rate,     # 最大学习率 (α_max)
    min_learning_rate,     # 最小学习率 (α_min)
    warmup_iters,          # Warmup 步数 (T_w)
    cosine_cycle_iters,    # 余弦周期总步数 (T_c)
):
    # Phase 1: 线性 Warmup (0 ~ T_w)
    if it < warmup_iters:
        return max_learning_rate * (it / warmup_iters)

    # Phase 2: 余弦退火 (T_w ~ T_c)
    if it < cosine_cycle_iters:
        progress = (it - warmup_iters) / (cosine_cycle_iters - warmup_iters)
        return min_learning_rate + 0.5 * (max_learning_rate - min_learning_rate) * (
            1 + math.cos(math.pi * progress)
        )

    # Phase 3: 保持最小学习率 (T_c ~ ∞)
    return min_learning_rate

6.3 学习率变化曲线

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
LR

α_max ────────────────────────
       ╱ ╲
       ╱   ╲
       ╱     ╲
       ╱       ╲
       ╱         ╲
       ╱           ╲
       ╱             ╲
       ╱               ╲
α_min ──────────────────────────────→ Steps
      ↑                ↑           ↑
      0            T_w    T_c    ∞

Phase 1: 线性 Warmup (从 0 增长到 α_max)
Phase 2: 余弦退火 (从 α_max 衰减到 α_min)
Phase 3: 保持 α_min

6.4 数学公式

Phase 1 - 线性 Warmup:

Phase 2 - 余弦退火:

其中

6.5 为什么需要 Warmup?

阶段 目的
Warmup 初始参数随机,大的学习率会导致不稳定;Warmup 让模型逐步适应
Cosine 衰减 后期小学习率精细调优,避免在最优解附近震荡
保持最小值 保持微小的探索能力,防止完全收敛到局部最优

6.6 训练配置示例

1
2
3
4
5
6
config = TrainingConfig(
    learning_rate=3e-4,      # α_max = 0.0003
    min_learning_rate=3e-5,  # α_min = 0.00003
    warmup_steps=500,        # T_w = 500
    total_steps=5000,        # T_c = 5000
)

7. 困惑度 (Perplexity)

7.1 定义

困惑度 (Perplexity, PPL) 是语言模型性能的标准度量,表示模型对下一个 token 的平均不确定性

7.2 计算

1
2
3
4
5
6
def compute_perplexity(loss: float) -> float:
    return math.exp(loss)

# 在训练日志中
avg_loss = sum(train_losses[-log_interval:]) / log_interval
ppl = compute_perplexity(avg_loss)

7.3 解释

PPL 值 含义
1.0 完美预测,模型完全确定
10 模型在 10 个 token 中选 1 个
100 模型在 100 个 token 中选 1 个
→ ∞ 模型完全随机猜测

8. 总结

本文档介绍了 Transformer 语言模型训练的核心组件:

  1. 批数据获取:随机采样,使用整个序列而非单点预测
  2. 模型输出:每个位置预测下一个 token,维度 [batch, seq, vocab]
  3. 损失函数:数值稳定的交叉熵实现
  4. 梯度裁剪:防止梯度爆炸,保持训练稳定
  5. 困惑度:衡量模型预测能力的指数度量

这些组件共同构成了现代语言模型训练的基础。