Transformer 解剖:从 Attention 到推理系统

第 9 章 搭一个 mini-GPT:从零训练能写古诗的小模型

作者 杨艺韬 · 4,469 字

第 9 章 搭一个 mini-GPT:从零训练能写古诗的小模型

第 8 章我们写了一个 production-grade 的 Self-Attention。这章我们把它装进一个完整的 Decoder-only Transformer,加上 token embedding、FFN、RMSNorm、训练循环、生成采样——从「能写一段 attention」走到「能训一个能用的小语言模型」。

为什么选「写古诗」做任务?三个原因:

  1. 数据集小而干净:《全唐诗》大约 4.3 万首诗、共 200 多万字符。这个量级一台 RTX 3060 / 4060 / Mac Pro 都能在几小时内训完。
  2. 结构性强:古诗有格律——四言、五言、七言、平仄、押韵。模型如果学到了这些结构,输出会立刻可识别地「像古诗」,给读者直观的成就感。
  3. 字符级训练就够了:中文古诗一字一意,可以直接用字符级 tokenizer,不需要 BPE。这让我们把焦点放在 Transformer 本身,而不是被 tokenizer 工程分心。

读完这章你能:

9.1 完整模型架构概览

我们要搭的 mini-GPT 长这样:

flowchart TB
  ID["输入 token id<br/>(B, T)"] --> EMB["Token Embedding<br/>(V, d_model)"]
  EMB --> BLK1[Transformer Block 1]
  BLK1 --> BLK2[Transformer Block 2]
  BLK2 --> BLKD[...]
  BLKD --> BLKL[Transformer Block L]
  BLKL --> RMS[Final RMSNorm]
  RMS --> HEAD["LM Head<br/>(d_model, V)"]
  HEAD --> LOG["logits<br/>(B, T, V)"]

  subgraph blkdetail ["Block 内部(第 5 章数据流)"]
    direction TB
    X[x] --> N1[RMSNorm]
    N1 --> ATT["Causal Self-Attention<br/>第 8 章"]
    ATT --> A1["+ residual"]
    X --> A1
    A1 --> N2[RMSNorm]
    N2 --> FFN[SwiGLU FFN]
    FFN --> A2["+ residual"]
    A1 --> A2
  end

参数估算:选 d_model=384, n_heads=6, n_layers=6, d_ffn=1024, vocab_size=8000

10M 量级的模型,单机训练 1-2 小时即可收敛。

9.2 数据准备:从《全唐诗》到张量

数据来源

公开的《全唐诗》数据可以从这几个地方获取(请遵守各自许可):

下载后得到一个大文本文件 tang_poems.txt,每行一首诗:

红豆生南国,春来发几枝。愿君多采撷,此物最相思。
床前明月光,疑是地上霜。举头望明月,低头思故乡。
...

字符级 tokenizer

对中文古诗,最简单的 tokenizer 就是「每个汉字 / 标点 / 换行符 = 一个 token」。这种 tokenizer 不需要训练,直接基于全文档统计:

def build_vocab(text):
    """
    从文本统计字符表,返回 (字 → id) 和 (id → 字) 两个字典
    """
    # 字符按出现频率排序,让常用字在前(id 小)—— 这只是好习惯,不影响模型
    from collections import Counter
    counter = Counter(text)
    chars = ['<pad>', '<bos>', '<eos>', '<unk>'] + [c for c, _ in counter.most_common()]
    char2id = {c: i for i, c in enumerate(chars)}
    id2char = {i: c for c, i in char2id.items()}
    return char2id, id2char

with open('tang_poems.txt', encoding='utf-8') as f:
    text = f.read()

char2id, id2char = build_vocab(text)
print(f"词表大小: {len(char2id)}")  # 大约 6000-8000 个字符

注意 4 个特殊 token:

把文本编码成 token id 序列

def encode(text):
    return [char2id.get(c, char2id['<unk>']) for c in text]

# 把每首诗包上 <bos> 和 <eos>,然后拼起来
poems = text.split('\n')
all_ids = []
for poem in poems:
    poem = poem.strip()
    if not poem:
        continue
    all_ids.extend([char2id['<bos>']] + encode(poem) + [char2id['<eos>']])

print(f"训练 token 总数: {len(all_ids)}")  # 大概 200-300 万

切成训练 batch

CLM 训练需要把长 token 流切成 (B, T) 的小段。最常见的方法是「随机采样」:

import torch

class PoemDataset(torch.utils.data.Dataset):
    def __init__(self, ids, block_size=128):
        self.ids = torch.tensor(ids, dtype=torch.long)
        self.block_size = block_size

    def __len__(self):
        return len(self.ids) - self.block_size - 1

    def __getitem__(self, idx):
        x = self.ids[idx : idx + self.block_size]
        y = self.ids[idx + 1 : idx + self.block_size + 1]
        return x, y  # x: 输入 token, y: 下一 token (shifted by 1)

__getitem__ 返回的 (x, y) 满足:y[i] = x[i+1]——这正是 CLM 的训练形式(位置 i 的 input 对应预测下一位置 y 的 token)。

dataset = PoemDataset(all_ids, block_size=128)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)

9.3 Transformer Block

把第 8 章的 CausalSelfAttention 加上 SwiGLU FFN 和 RMSNorm,得到一个完整的 Block:

class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        # x: (B, T, D)
        rms = x.pow(2).mean(dim=-1, keepdim=True).add_(self.eps).sqrt_()
        return self.gamma * (x / rms)


class SwiGLU(nn.Module):
    def __init__(self, d_model, d_ffn):
        super().__init__()
        self.w = nn.Linear(d_model, d_ffn, bias=False)
        self.v = nn.Linear(d_model, d_ffn, bias=False)
        self.w2 = nn.Linear(d_ffn, d_model, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w(x)) * self.v(x))   # silu == swish


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ffn, max_seq_len):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, max_seq_len)   # 第 8 章
        self.norm2 = RMSNorm(d_model)
        self.ffn = SwiGLU(d_model, d_ffn)

    def forward(self, x):
        # Pre-LN(第 5.4 节),residual 不经过 norm
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

这个 Block 的形状流转:输入 (B, T, D) → 输出 (B, T, D),可以无限堆叠。

9.4 完整 mini-GPT

class MiniGPT(nn.Module):
    def __init__(self, vocab_size, d_model=384, n_heads=6, n_layers=6,
                 d_ffn=1024, max_seq_len=512):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ffn, max_seq_len)
            for _ in range(n_layers)
        ])
        self.norm_f = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # 共享 embedding 和 lm_head 的权重 (weight tying)
        self.lm_head.weight = self.tok_emb.weight

        # 初始化(GPT-2 风格)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, ids, targets=None):
        # ids: (B, T)
        x = self.tok_emb(ids)              # (B, T, D)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm_f(x)
        logits = self.lm_head(x)            # (B, T, V)

        if targets is None:
            return logits

        # 计算 loss
        # logits: (B, T, V);targets: (B, T)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=0,                 # 忽略 <pad>
        )
        return logits, loss

值得停下来强调几个细节:

1. weight tyingself.lm_head.weight = self.tok_emb.weight 让输出投影和输入嵌入共享权重。这能省掉 vocab_size * d_model 的参数(在我们的设置下约 3M 参数,占总参数的 18%)。GPT-2、Llama 1/2 都用这个技巧,Llama 3 因为词表扩到 128K 而放弃了共享。

2. 初始化:GPT-2 风格的 std=0.02 正态初始化。LayerNorm / RMSNorm 的 gamma 初始化为 1,bias 为 0。Llama 在更深的网络上还会对 residual 后的 weights 做额外缩放(std=0.02 / sqrt(2 * n_layers)),让深层训练更稳定——我们这小模型不需要。

3. ignore_indexF.cross_entropy(..., ignore_index=0) 让 loss 跳过 <pad> token 的位置。这样我们可以把不同长度的序列 padding 到统一长度而不污染 loss。

9.5 训练循环

import time
from torch.optim.lr_scheduler import LambdaLR


def get_lr_scheduler(optimizer, warmup_steps, total_steps, min_lr_ratio=0.1):
    """
    线性 warmup + cosine decay
    """
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # cosine decay 从 1.0 到 min_lr_ratio
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return min_lr_ratio + 0.5 * (1 - min_lr_ratio) * (1 + math.cos(math.pi * progress))
    return LambdaLR(optimizer, lr_lambda)


def train(model, loader, total_steps=20000, lr=3e-4, warmup=2000,
          grad_clip=1.0, log_interval=100, device='cuda'):
    model.to(device)
    model.train()

    # AdamW 是 GPT 系列的标配
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.95),                  # GPT 风格 beta2=0.95,比默认 0.999 小
        weight_decay=0.1,
    )
    scheduler = get_lr_scheduler(optimizer, warmup, total_steps)

    step = 0
    t0 = time.time()
    losses = []
    iter_loader = iter(loader)
    while step < total_steps:
        try:
            x, y = next(iter_loader)
        except StopIteration:
            iter_loader = iter(loader)
            x, y = next(iter_loader)

        x, y = x.to(device), y.to(device)
        logits, loss = model(x, targets=y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        # 梯度裁剪 —— 防止梯度爆炸
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()

        losses.append(loss.item())
        step += 1

        if step % log_interval == 0:
            avg = sum(losses[-log_interval:]) / log_interval
            elapsed = time.time() - t0
            tokens_per_sec = (log_interval * x.numel()) / elapsed
            print(f"step {step:6d} | lr {scheduler.get_last_lr()[0]:.2e} | "
                  f"loss {avg:.4f} | {tokens_per_sec:.0f} tok/s")
            t0 = time.time()

几个工程细节:

优化器 AdamW:beta2=0.95(不是默认 0.999)来自 GPT-2 论文——大模型训练经验显示 beta2 小一些更稳定。weight_decay=0.1 给参数加 L2 正则。

Learning rate schedule:线性 warmup(前 2000 步从 0 升到 lr)+ cosine decay(从 lr 降到 0.1×lr)。这是 Llama / GPT 系列标配的 schedule。

Gradient clippingclip_grad_norm_(model.parameters(), 1.0) 把梯度的全局范数裁剪到 1.0。这是防止梯度爆炸(特别是训练初期)的标准做法。

zero_grad(set_to_none=True):把梯度设为 None 而不是 0——节省一次显存写入,速度略快。

9.6 生成:top-k + temperature 采样

训练完模型,要让它「写诗」需要一个生成函数。最朴素的「贪心生成」(每次取最大概率的 token)输出会很无聊,重复多。生产级用 top-k + temperature

@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens=64, temperature=0.8, top_k=50):
    model.eval()
    ids = torch.tensor([prompt_ids], dtype=torch.long, device=next(model.parameters()).device)
    
    for _ in range(max_new_tokens):
        # 截断:只用最后 max_seq_len 个 token(如果 prompt 太长)
        ids_cond = ids[:, -512:]
        
        # 前向得到下一 token 的 logits
        logits = model(ids_cond)            # (1, T, V)
        logits = logits[:, -1, :] / temperature   # 只看最后一个位置;温度缩放
        
        # top-k 截断:保留概率最大的 k 个候选
        v, _ = torch.topk(logits, top_k)
        logits[logits < v[:, [-1]]] = float('-inf')
        
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        
        # 遇到 <eos> 停止
        if next_id.item() == char2id['<eos>']:
            break
        
        ids = torch.cat([ids, next_id], dim=1)
    
    return ids.squeeze().tolist()

temperature 是「采样的随机性」。温度小(如 0.5)logits 被放大,分布尖锐,结果保守;温度大(如 1.2)分布平坦,结果更随机更有创意。古诗这种「需要严谨格律 + 一点意外」的任务,0.7-0.9 是甜点。

top-k 是「只看概率前 k 大的候选」。把极小概率的 token 过滤掉,避免生成奇怪的字符。50 是常用值;古诗任务可以小一些(10-30)保持格律。

更进阶的采样方法(top-p / nucleus、min-p、temperature scaling 等)原理类似但更精细——nanoGPT、HuggingFace Transformers 都有现成实现。

9.7 跑一次完整训练

把所有零件串起来:

# 配置
VOCAB_SIZE = len(char2id)
D_MODEL = 384
N_HEADS = 6
N_LAYERS = 6
D_FFN = 1024
BLOCK_SIZE = 128
BATCH_SIZE = 64
TOTAL_STEPS = 20000

# 模型
model = MiniGPT(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    d_ffn=D_FFN,
    max_seq_len=BLOCK_SIZE * 2,        # RoPE 预计算容量留余量
)

# 数据
dataset = PoemDataset(all_ids, block_size=BLOCK_SIZE)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True
)

# 训练
train(model, loader, total_steps=TOTAL_STEPS, lr=3e-4, warmup=1000, device='cuda')

# 保存
torch.save(model.state_dict(), 'mini_gpt_poem.pt')

在一张 RTX 4060(16GB)上,这个配置每秒大约处理 30K-50K token,20000 步训练约 2-3 小时完成。loss 从初始的约 8.0(接近 ln(VOCAB_SIZE) ≈ 8.99,对应均匀分布)降到约 3.5-4.5——已经远好于均匀分布,说明模型学到了相当多结构。

flowchart LR
  A[step 0<br/>loss ~ 8.99<br/>均匀分布] --> B[step 500<br/>loss ~ 6.0<br/>学会字频]
  B --> C[step 2000<br/>loss ~ 5.0<br/>学会常见词]
  C --> D[step 5000<br/>loss ~ 4.3<br/>学会平仄]
  D --> E[step 10000<br/>loss ~ 4.0<br/>学会五言/七言]
  E --> F[step 20000<br/>loss ~ 3.7<br/>能写完整诗]

9.8 生成示例:让它写诗

训练完后看看模型能写出什么:

def generate_poem(prompt, max_tokens=50, temperature=0.8, top_k=20):
    prompt_ids = [char2id['<bos>']] + encode(prompt)
    out_ids = generate(model, prompt_ids, max_new_tokens=max_tokens,
                       temperature=temperature, top_k=top_k)
    return ''.join(id2char[i] for i in out_ids if i not in (0, 1, 2))

print(generate_poem("春风", temperature=0.8))
print(generate_poem("月落", temperature=0.7))
print(generate_poem("故人", temperature=0.9))

实测:1.71M 参数的 mini-GPT 在 19,397 首唐诗上跑 1500 步

下面是本书配套代码(docs/books/transformer/code/)的真实跑通记录——为了让 Mac 笔记本也能在 10 分钟内跑完,我把模型缩到 d_model=128 / 4 层 / 4 头(1.71M 参数),用 chinese-poetry 仓库的前 20K 首唐诗做训练集(去重后 19,397 首,词表 6,705 个汉字)。

Loss 曲线(CPU 上 1500 步约 8 分 35 秒):

mini-GPT 训练 loss 曲线

从初始 8.74(接近 ln(6705) = 8.81,模型只是均匀分布)一路降到 4.65——前 200 步模型在快速学字频,200-1000 步学常用搭配,1000 步之后是边际优化。这条曲线就是「模型在 8 分钟里学会唐诗」的全部过程。

实际生成示例(temperature=0.85, top_k=15):

「春风」→ 春?青青山,春風吹錦衾。自憐歌管弦,更作舞人衣。
「月落」→ 月落花飛盡,山寒月滿深。風流萬里外,月滿兩山時。
「故人」→ 故人行樂少年少,一去春風不得歸。一去不知無事事,更知長者到人心。
「清夜」→ 清夜無風景,秋天一夜寒。雲高風景靜,山上日星寒。

(数据是 chinese-poetry 仓库的繁体字版本,所以输出也是繁体。「?」是 <unk> 的占位——某个低频字不在 6,705 词表里。)

注意几个有意思的现象:

  1. 字数对齐:模型自动学到了五言("月落花飛盡"5 字)、七言("故人行樂少年少"7 字)的节奏——这是从字符级 token 预测中自发涌现的格律。
  2. 意境关联:「春风→吹」「月落→花飞」「故人→行乐」「清夜→无风景」——这些搭配反映了 attention 学到的字间语义聚类。
  3. 失败模式:「一去春风不得归。一去不知无事事」——同一开头重复出现,是 1.71M 小模型容量不足的典型征兆。
  4. 押韵的萌芽:「秋天一夜寒…山上日星寒」——「寒」字押韵涌现,但还不严格。

不完美的地方:

  1. 平仄不严:律诗对平仄有严格要求,1.7M 的小模型学不到这层。
  2. 少数生成像「拼贴」:base 模型常见的「重复学习内容」——足够大的模型 + 多样的数据可以缓解。
  3. 逻辑连贯性有限:能做到字面通顺,但深层语义连贯不如大模型。

这是 1.71M 小模型在唐诗上 8 分钟训练的合理表现——和 GPT-2 1.5B 比当然差很多,但作为「Transformer 学习项目」已经能给读者非常直观的成就感。把 d_model 加到 384 / 层数加到 6 / 训练步数加到 20000,loss 能进一步降到 3.7 左右——基本接近「能写出像样七律」的水平。

9.9 训练时常见问题

问题 1:Loss 不降或者飙到 NaN

最常见原因:

问题 2:训练 loss 很低但生成很烂

原因:过拟合。模型记住了训练数据的局部模式但没学到泛化能力。

调试方法:拿一份 held-out validation set(比如 5% 的诗),同步监测 val loss。如果 val loss 比 train loss 高很多(差距 > 0.5)就是过拟合。

应对:

问题 3:生成结果重复

原因:贪心或低 temperature 时模型陷入局部模式重复。

应对:

问题 4:训练慢

排查:

加上 mixed precision + torch.compile

model = torch.compile(model)
scaler = torch.amp.GradScaler('cuda')

# 训练循环里:
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
    logits, loss = model(x, targets=y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

9.10 从 mini-GPT 到 Llama:缺什么

我们这个 mini-GPT 已经包含了 Llama 的所有核心组件。要从这扩展到真正的 Llama 7B,主要差异是:

维度 mini-GPT Llama 7B 差距
模型规模 ~10M 7B 700×
训练数据 ~3M token 1-2T token 500K×
上下文 128 4096+ 32×
词表 ~8K(字符级) 32K(BPE)
Attention MHA GQA(h_kv=4 或 8) 优化
FFN SwiGLU SwiGLU
位置编码 RoPE RoPE
归一化 RMSNorm Pre-LN RMSNorm Pre-LN
训练框架 单卡 PyTorch 多卡 / Megatron / DeepSpeed 分布式
后训练 SFT + RLHF/DPO 对齐

也就是说架构层面我们已经基本和 Llama 一致——剩下的「鸿沟」全在工程:分布式训练、超大数据集处理、长上下文处理、对齐训练。这些每一项都值得另开一本书来讲。

第 10 章会先把「词表 / Tokenizer」这一格补全——讲清楚字符级、词级、BPE / SentencePiece / Tiktoken 这条演化路线,让你能从「字符级 mini-GPT」升级到工业级 tokenizer。

本章小结

  1. mini-GPT 是第 5 章 Block + 第 8 章 Attention 的完整组装——加 token embedding、RMSNorm、SwiGLU FFN、LM head 就成了一个能训的 Decoder-only Transformer。
  2. 字符级 tokenizer 适合中文小数据场景——简单粗暴,跳过 BPE 工程。
  3. 训练循环的关键齿轮:AdamW(beta2=0.95,weight_decay=0.1)+ 线性 warmup + cosine decay + grad clipping。
  4. 采样三件套:temperature、top-k、top-p——温度调随机性、top-k/p 截断尾部。
  5. 10M 模型在《全唐诗》上能学到格律和押韵——loss 从 8.99 降到 ~3.7,生成结果可识别地像古诗。
  6. 常见 bug:lr 太大爆 NaN、过拟合 val loss 上升、生成重复——解决方法是降 lr / 加 dropout / 提温度。
  7. 从 mini-GPT 到 Llama 7B 的鸿沟全在工程:分布式训练、数据规模、tokenizer 工程、长上下文、对齐。架构本身已经在我们手里了。

下一章讲 tokenizer 工程:从字符级到 BPE 到 SentencePiece,理解今天大模型为什么用 BPE / Tiktoken 这套体系,以及 tokenizer 选择对模型性能的微妙影响。

延伸阅读