Transformer 解剖:从 Attention 到推理系统

第 8 章 50 行 PyTorch 实现 Self-Attention

作者 杨艺韬 · 4,497 字

第 8 章 50 行 PyTorch 实现 Self-Attention

第 2 到第 5 章我们把 Self-Attention 的数学拆透了。这一章把它落到代码上——用纯 PyTorch 写一个能跑、能调、能验证正确性的 Self-Attention 模块。

为什么要亲手写一遍?读公式和读代码是两件事。公式可以掩盖工程细节——张量怎么 reshape、mask 怎么 broadcast、softmax 沿哪一维做、混合精度下哪些算子要 cast——这些只有在写代码时才会暴露。亲手写一遍 Self-Attention 之后,你看 vLLM、Flash Attention、PyTorch 内置 nn.MultiheadAttention 时不会再有任何「不知道里面在干嘛」的困惑。

我们会按下面的顺序逐步搭:

  1. 最朴素版 single-head:把第 2 章的公式直接翻译成 12 行代码
  2. 升级到 multi-head:把第 3 章的「打包成大矩阵乘」工程实现写出来
  3. 加 causal mask:让它能用于 Decoder-only 自回归
  4. 加 RoPE:把第 4 章的旋转位置编码接上
  5. 封装成 nn.Module:可以直接塞进 Transformer Block

最终代码不到 80 行,但它就是 Llama / GPT 推理时跑的那段 attention——只是没做 Flash Attention 的 SRAM 优化(那是第 18 章)。

8.1 准备:环境与张量约定

先约定本章的张量形状记号:

代码用 PyTorch 写,假设你已经熟悉 torch.nn.Modulenn.Linear、自动求导这些基础。如果不熟,建议先看一遍《动手学深度学习》第 4-5 章。

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

8.2 第一版:12 行的朴素 Self-Attention

我们先写最贴近第 2 章公式 Attention(Q,K,V)=softmax(QKT/dk)V\text{Attention}(Q, K, V) = \text{softmax}(QK^T/\sqrt{d_k})V 的版本。这版只有 single head、不带 mask、不带位置编码——纯粹的「数学翻译」。

def naive_attention(x, W_q, W_k, W_v):
    """
    x:   (B, T, D)
    W_*: (D, D)  Q/K/V 投影矩阵
    return: (B, T, D)
    """
    Q = x @ W_q                     # (B, T, D)
    K = x @ W_k                     # (B, T, D)
    V = x @ W_v                     # (B, T, D)

    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1)            # (B, T, T)
    scores = scores / math.sqrt(d_k)            # 缩放
    attn = F.softmax(scores, dim=-1)            # 沿 K 维归一化
    out = attn @ V                              # (B, T, D)
    return out

12 行(注释除外)。每一行都对应公式里的一个步骤:

公式 代码
Q=XWQQ = X W_Q Q = x @ W_q
K=XWKK = X W_K K = x @ W_k
V=XWVV = X W_V V = x @ W_v
S=QKTS = QK^T scores = Q @ K.transpose(-2, -1)
S=S/dkS' = S / \sqrt{d_k} scores = scores / math.sqrt(d_k)
A=softmax(S)A = \text{softmax}(S') attn = F.softmax(scores, dim=-1)
out=AV\text{out} = AV out = attn @ V

注意几个 PyTorch 细节:

跑一下让它工作:

B, T, D = 2, 5, 16
x = torch.randn(B, T, D)
W_q, W_k, W_v = (torch.randn(D, D) for _ in range(3))
out = naive_attention(x, W_q, W_k, W_v)
print(out.shape)  # torch.Size([2, 5, 16])

正确输出形状 (2, 5, 16)——和输入一致。这样我们的「最小可工作版」就跑通了。

8.3 第二版:multi-head 的工程实现

现在升级到 Multi-Head。回忆第 3 章 3.3 节,所有头被打包到一次大矩阵乘法里——不是 for 循环跑 H 次,而是用 reshape 把维度从 (B, T, D) 拆成 (B, H, T, Dh)

def multi_head_attention(x, W_q, W_k, W_v, W_o, H):
    """
    x:    (B, T, D)
    W_q/k/v: (D, D)  打包后的 QKV 投影(每个头的投影拼起来)
    W_o:  (D, D)     输出投影
    H:    int        头数
    """
    B, T, D = x.shape
    Dh = D // H

    # 一次大矩阵乘 → 拆成多头
    Q = (x @ W_q).view(B, T, H, Dh).transpose(1, 2)   # (B, H, T, Dh)
    K = (x @ W_k).view(B, T, H, Dh).transpose(1, 2)   # (B, H, T, Dh)
    V = (x @ W_v).view(B, T, H, Dh).transpose(1, 2)   # (B, H, T, Dh)

    scores = Q @ K.transpose(-2, -1) / math.sqrt(Dh)  # (B, H, T, T)
    attn = F.softmax(scores, dim=-1)                   # (B, H, T, T)
    out = attn @ V                                     # (B, H, T, Dh)

    # 把多头合回去
    out = out.transpose(1, 2).contiguous().view(B, T, D)  # (B, T, D)
    return out @ W_o                                       # (B, T, D)

关键变化在 view + transpose 这两行:

(x @ W_q)             # (B, T, D)
   .view(B, T, H, Dh) # (B, T, H, Dh)  把 D 拆成 H × Dh
   .transpose(1, 2)   # (B, H, T, Dh)  把 H 调到前面

为什么要 .transpose(1, 2)?因为后面的 attention 计算 Q @ K.transpose(-2, -1) 是在最后两维上做矩阵乘——我们希望「头维」是一个独立的 batch 维(让 PyTorch 自动并行处理 H 个头),所以要把它放到 batch 那一侧。

注意 attention 计算后的 out 形状是 (B, H, T, Dh),要合回 (B, T, D)

out.transpose(1, 2)            # (B, T, H, Dh)
   .contiguous()               # 物理上重新排布内存
   .view(B, T, D)              # (B, T, D)  H × Dh 重新合成 D

.contiguous() 是必须的——transpose 只是改了 stride 信息(虚拟地交换维度),物理内存还是原样;view 要求张量在内存中是连续的,所以要先 .contiguous()

最后一行 out @ W_o 是输出投影(第 3.2 节的 WOW_O),把多头拼起来后做一次最终的线性变换。

flowchart LR
  X["x (B,T,D)"] --> WQ["× W_q"]
  WQ --> Q1["(B,T,D)"]
  Q1 --> RV["view (B,T,H,Dh)"]
  RV --> TR["transpose (B,H,T,Dh)"]
  TR --> Q[Q]
  Q --> ATT["batched<br/>scaled<br/>dot product<br/>attention"]
  X --> KV[同样得到 K, V]
  KV --> ATT
  ATT --> O1["(B,H,T,Dh)"]
  O1 --> TR2["transpose + reshape (B,T,D)"]
  TR2 --> WO["× W_o"]
  WO --> OUT["输出 (B,T,D)"]

8.4 第三版:加上 causal mask

Decoder-only 模型(GPT、Llama)需要因果掩码——位置 i 的 query 只能看 ≤i 的 key。第 2.10 节我们讲过这件事的实现:在 softmax 之前把上三角 score 设为 −∞。

PyTorch 里有个常用工具 torch.tril

def causal_mask(T):
    """
    返回一个 (T, T) 的下三角矩阵,下三角和对角线为 0,上三角为 -inf
    """
    mask = torch.full((T, T), float('-inf'))
    mask = torch.triu(mask, diagonal=1)
    return mask  # 上三角是 -inf,其他是 0

或者更直接:

mask = torch.zeros(T, T)
mask[torch.triu(torch.ones(T, T), diagonal=1).bool()] = float('-inf')

把它加到 score 上:

def causal_attention(x, W_q, W_k, W_v, W_o, H):
    B, T, D = x.shape
    Dh = D // H

    Q = (x @ W_q).view(B, T, H, Dh).transpose(1, 2)
    K = (x @ W_k).view(B, T, H, Dh).transpose(1, 2)
    V = (x @ W_v).view(B, T, H, Dh).transpose(1, 2)

    scores = Q @ K.transpose(-2, -1) / math.sqrt(Dh)   # (B, H, T, T)

    # —— 关键:加上因果掩码 ——
    mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
    scores = scores + mask.to(scores.device)            # 自动 broadcast 到 (B, H, T, T)

    attn = F.softmax(scores, dim=-1)
    out = attn @ V

    out = out.transpose(1, 2).contiguous().view(B, T, D)
    return out @ W_o

加 mask 这一步几个细节:

验证一下 mask 的效果:

B, T, D, H = 1, 4, 8, 2
x = torch.randn(B, T, D)
W_q, W_k, W_v, W_o = (torch.randn(D, D) for _ in range(4))

# 两个版本的输出
out_full   = multi_head_attention(x, W_q, W_k, W_v, W_o, H)
out_causal = causal_attention(x, W_q, W_k, W_v, W_o, H)

# causal 版本的位置 0 应该和「只把序列截断到位置 0 时」算出来的输出一致
x_truncated = x[:, :1, :]                                # (1, 1, D)
out_truncated = causal_attention(x_truncated, W_q, W_k, W_v, W_o, H)
print(torch.allclose(out_causal[:, :1], out_truncated))  # 应该是 True

这是 causal mask 必备的「自一致性」检验:截断输入后第 0 个位置的输出,应该等于完整输入下第 0 个位置的输出(因为 causal mask 让位置 0 看不到位置 1, 2, 3)。

8.5 第四版:加上 RoPE

我们继续把第 4 章的 RoPE 加进来。RoPE 旋转的对象是 Q 和 K(不是 V),位置在「投影后、attention 计算前」。

def precompute_rope(Dh, max_T, base=10000):
    """
    预计算 RoPE 用的 cos / sin 表
    Dh: 每个头的维度
    max_T: 最大序列长度
    return: cos, sin (max_T, Dh)
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, Dh, 2).float() / Dh))   # (Dh/2,)
    pos = torch.arange(max_T).float()                                    # (max_T,)
    freqs = pos.unsqueeze(-1) * inv_freq.unsqueeze(0)                    # (max_T, Dh/2)
    # 把每个频率重复一次(对 (real, imag) 两个分量都用同一频率)
    cos = freqs.cos().repeat_interleave(2, dim=-1)                       # (max_T, Dh)
    sin = freqs.sin().repeat_interleave(2, dim=-1)                       # (max_T, Dh)
    return cos, sin

def apply_rope(x, cos, sin):
    """
    x:   (B, H, T, Dh)
    cos: (T, Dh)
    sin: (T, Dh)
    return: 旋转后的 x
    """
    # 把 x 拆成偶数维和奇数维交错对
    x1, x2 = x[..., 0::2], x[..., 1::2]                                  # (B, H, T, Dh/2)
    x_rot = torch.stack([-x2, x1], dim=-1).flatten(-2)                   # (B, H, T, Dh)
    # cos / sin 的 broadcast
    cos = cos[None, None, :, :]                                           # (1, 1, T, Dh)
    sin = sin[None, None, :, :]
    return x * cos + x_rot * sin

precompute_rope 是一次性的预计算——同一个模型只算一次,cache 起来反复用。apply_rope 是每次前向都要做的旋转操作。

把 RoPE 接到 attention 里:

def causal_attention_with_rope(x, W_q, W_k, W_v, W_o, H, cos, sin):
    B, T, D = x.shape
    Dh = D // H

    Q = (x @ W_q).view(B, T, H, Dh).transpose(1, 2)   # (B, H, T, Dh)
    K = (x @ W_k).view(B, T, H, Dh).transpose(1, 2)
    V = (x @ W_v).view(B, T, H, Dh).transpose(1, 2)

    # —— 新增:对 Q、K 应用 RoPE,V 不旋转 ——
    Q = apply_rope(Q, cos[:T], sin[:T])
    K = apply_rope(K, cos[:T], sin[:T])

    scores = Q @ K.transpose(-2, -1) / math.sqrt(Dh)
    mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
    scores = scores + mask.to(scores.device)
    attn = F.softmax(scores, dim=-1)
    out = attn @ V

    out = out.transpose(1, 2).contiguous().view(B, T, D)
    return out @ W_o

注意 cos[:T]sin[:T] 是 slice——只取当前序列长度需要的部分(max_T 是预计算的最大值,实际序列可能更短)。

8.6 第五版:封装成 nn.Module

最后我们把所有零件封装成一个 nn.Module,方便塞进完整的 Transformer:

class CausalSelfAttention(nn.Module):
    """
    Decoder-only Transformer 的核心 attention 模块。
    包含: Multi-Head + RoPE + Causal Mask
    """
    def __init__(self, d_model, n_heads, max_seq_len=8192, rope_base=10000):
        super().__init__()
        assert d_model % n_heads == 0, "d_model 必须能被 n_heads 整除"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # QKV 三个投影合并成一次大 Linear(fused),减少一次 launch 开销
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        # RoPE 预计算(注册为 buffer,不参与训练但跟随 .to(device))
        cos, sin = self._precompute_rope(self.d_head, max_seq_len, rope_base)
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)

    @staticmethod
    def _precompute_rope(d_head, max_t, base):
        inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
        pos = torch.arange(max_t).float()
        freqs = pos.unsqueeze(-1) * inv_freq.unsqueeze(0)
        cos = freqs.cos().repeat_interleave(2, dim=-1)
        sin = freqs.sin().repeat_interleave(2, dim=-1)
        return cos, sin

    @staticmethod
    def _apply_rope(x, cos, sin):
        x1, x2 = x[..., 0::2], x[..., 1::2]
        x_rot = torch.stack([-x2, x1], dim=-1).flatten(-2)
        return x * cos[None, None] + x_rot * sin[None, None]

    def forward(self, x):
        B, T, D = x.shape
        H, Dh = self.n_heads, self.d_head

        # 一次 Linear 算出 QKV
        qkv = self.qkv_proj(x)                                 # (B, T, 3D)
        Q, K, V = qkv.split(D, dim=-1)                          # 每个 (B, T, D)
        Q = Q.view(B, T, H, Dh).transpose(1, 2)                # (B, H, T, Dh)
        K = K.view(B, T, H, Dh).transpose(1, 2)
        V = V.view(B, T, H, Dh).transpose(1, 2)

        # RoPE
        Q = self._apply_rope(Q, self.rope_cos[:T], self.rope_sin[:T])
        K = self._apply_rope(K, self.rope_cos[:T], self.rope_sin[:T])

        # Scaled dot-product
        scores = Q @ K.transpose(-2, -1) / math.sqrt(Dh)       # (B, H, T, T)

        # Causal mask
        mask = torch.triu(torch.full((T, T), float('-inf'), device=x.device), diagonal=1)
        scores = scores + mask

        attn = F.softmax(scores, dim=-1)
        out = attn @ V                                          # (B, H, T, Dh)

        # Concat heads + output projection
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.out_proj(out)

总共 50 行左右(含注释)。这就是一个 production-grade 的 Causal Self-Attention with RoPE。

测试一下它能跑:

torch.manual_seed(42)
B, T, D, H = 2, 16, 128, 4
attn = CausalSelfAttention(d_model=D, n_heads=H, max_seq_len=64)
x = torch.randn(B, T, D)
out = attn(x)
print(out.shape)              # torch.Size([2, 16, 128])
print(out.requires_grad)      # True,可以反向传播

8.7 几个进阶工程考虑

Fused QKV 投影

我们用 nn.Linear(d_model, 3 * d_model) 把 Q/K/V 三个投影合并成一次。这是一个常见优化:

GPT-2、Llama 等主流模型的实现都用这种 fused 形式。但要注意:用 GQA / MQA 时,K 和 V 的输出维度小于 Q(因为它们的头数少),不能简单地用 3 * d_model——要分开处理:

# GQA 情况
self.q_proj = nn.Linear(d_model, n_heads * d_head, bias=False)
self.k_proj = nn.Linear(d_model, n_kv_heads * d_head, bias=False)
self.v_proj = nn.Linear(d_model, n_kv_heads * d_head, bias=False)

混合精度训练

实际工程几乎一定用混合精度(FP16 / BF16)训练:

attn = CausalSelfAttention(D, H).to(torch.bfloat16).cuda()
x = torch.randn(B, T, D, dtype=torch.bfloat16, device='cuda')
out = attn(x)  # 全部在 BF16 下计算

但有一个数值稳定细节:softmax 一定要在 FP32 下做。因为 softmax 涉及 exp(),BF16 的精度可能让 exp 结果溢出或下溢。PyTorch 的 F.softmax 默认会自动 upcast 到 FP32 计算再 downcast,但如果你手动写,要注意:

scores_fp32 = scores.float()             # cast 到 FP32
attn = F.softmax(scores_fp32, dim=-1)
attn = attn.to(scores.dtype)             # cast 回原 dtype

用 PyTorch 内置的 scaled_dot_product_attention

PyTorch 2.0+ 提供了 F.scaled_dot_product_attention(SDPA)——它内置了 Flash Attention 等优化:

out = F.scaled_dot_product_attention(
    Q, K, V,
    attn_mask=None,
    is_causal=True,           # 自动应用 causal mask,不需要手动构造
    dropout_p=0.0,
)

实际工程里,写完原理性的代码后,应该在生产时切换到 SDPA——它能自动选择 Flash Attention / Memory-Efficient Attention 等更快的实现。我们在第 18 章会讲 Flash Attention 内部到底怎么做的。

8.8 调试技巧

写 attention 时容易踩的几个坑,调试技巧如下:

坑 1:softmax 沿错维度

F.softmax(scores, dim=-1)dim 必须是 K 维(最后一维)。如果 attention 矩阵形状是 (B, H, T_q, T_k),softmax 应该归一化 T_k 维。写错成 dim=-2 就完蛋。

调试方法:检查 attention 矩阵每一行的和是否为 1:

print(attn.sum(dim=-1))  # 应该全是 1.0

坑 2:mask 形状错位

mask 应该是 (T, T) 或可以 broadcast 到 (B, H, T, T) 的形状。如果你不小心写成了 (B, T) 这种形状(混淆了「padding mask」和「causal mask」),broadcast 会出意料的结果。

调试方法:手算一个小例子(比如 T=4),打印 mask 后的 score 矩阵,肉眼检查上三角是不是 -inf。

坑 3:viewtranspose 顺序错

(B, T, D) → (B, T, H, Dh).view 是对的;(B, T, H, Dh) → (B, H, T, Dh) 必须用 .transpose不能用 .view。原因:view 要求维度的展开顺序和内存连续顺序一致——你不能跳着重排维度。

调试方法:每次 reshape 后打印 stride:

print(x.stride())
print(x.is_contiguous())

坑 4:忘了 .contiguous()

transpose 之后 view 之前一定要 .contiguous(),否则 PyTorch 会报:

RuntimeError: view size is not compatible with input tensor's size and stride

坑 5:device / dtype 不一致

mask、cos、sin 这些预计算的张量要和 input 在同一个 device、同一个 dtype。register_buffer 是把它们「绑定」到模块上的标准做法——.to(device) 调用时会自动迁移。

8.9 用 unittest 做自一致性检验

最后给一个完整的测试脚本,验证我们的实现没问题:

def test_causal_self_attention():
    torch.manual_seed(42)
    B, T, D, H = 2, 16, 128, 4
    attn = CausalSelfAttention(d_model=D, n_heads=H, max_seq_len=64)
    x = torch.randn(B, T, D)

    # 1. 形状测试
    out = attn(x)
    assert out.shape == (B, T, D), f"shape mismatch: {out.shape}"

    # 2. 因果性测试: 截断输入后位置 0 输出应不变
    out_full = attn(x)
    out_truncated = attn(x[:, :8])
    assert torch.allclose(out_full[:, :8], out_truncated, atol=1e-5)

    # 3. 反向传播测试: 可以 backward 到 input
    loss = out.sum()
    loss.backward()
    print("✓ 所有测试通过")

test_causal_self_attention()

这个测试覆盖了:

实际写一个新 attention 实现时,先写测试再写代码——这样能在 100 行代码里少走一周弯路。

8.10 这版代码与 Llama / GPT 的差别

读到这里你可能会问:这 80 行代码和真正的 Llama / GPT 的 attention 代码差多少?

差别其实不大。Llama 2 / Llama 3 的 attention 实现核心和我们的版本几乎一样,主要差异在:

  1. GQA(Grouped-Query Attention):K 和 V 的头数是 Q 的 1/8。在我们的代码里要把 Q 投影维度(H * Dh)和 K/V 投影维度(H_kv * Dh)分开。
  2. Flash Attention:用 F.scaled_dot_product_attention 或 xformers 的优化版替换我们手写的「scores → softmax → attn @ V」三步——内存效率高很多(第 18 章详解)。
  3. KV Cache:推理时缓存 K、V 而不是每次重算(第 15 章详解)。
  4. MLA(Multi-Head Latent Attention,DeepSeek-V2/V3):把 KV 压缩到一个低维 latent 空间——和 GQA 思路类似,但更激进。

但所有这些扩展都是在我们这版代码的基础上加东西,不会推翻这个核心。所以你掌握了第 8 章这版代码,就掌握了今天主流大模型 attention 实现的 90%

剩下的 10%(KV Cache、Flash Attention)是第六部分(推理系统)的内容。

本章小结

  1. 从纸面到代码只需 12 行——naive attention 基本就是公式的逐行翻译。
  2. 多头的工程实现是 reshape 而不是 for 循环——所有头被打包到一次 batched matmul 里跑。
  3. causal mask 通过加上一个上三角 -inf 实现——softmax 后这些位置贡献 0,相当于「看不见未来」。
  4. RoPE 旋转的是 Q 和 K,不是 V——预计算 cos/sin 表,每次前向 broadcast 应用。
  5. 封装成 nn.Module 后这是一个 production-grade 的实现——用 buffer 注册预计算张量、用 fused QKV 投影减少 launch 开销。
  6. 几个常见坑:softmax 沿错维度、mask 形状错位、view/transpose 顺序错、忘 contiguous、device/dtype 不一致。
  7. PyTorch 2.0+ 的 SDPA 内置 Flash Attention,生产环境推荐用它替换手写实现。
  8. 这版代码和真正的 Llama / GPT 几乎一致,差别在 GQA、Flash Attention、KV Cache 这些工程优化。

下一章我们用这套 attention 搭一个完整的 mini-GPT,加上 embedding、FFN、LayerNorm、训练循环——目标是从零训练一个能写古诗的小模型。

延伸阅读