Transformer 解剖:从 Attention 到推理系统
第 8 章 50 行 PyTorch 实现 Self-Attention
第 8 章 50 行 PyTorch 实现 Self-Attention
第 2 到第 5 章我们把 Self-Attention 的数学拆透了。这一章把它落到代码上——用纯 PyTorch 写一个能跑、能调、能验证正确性的 Self-Attention 模块。
为什么要亲手写一遍?读公式和读代码是两件事。公式可以掩盖工程细节——张量怎么 reshape、mask 怎么 broadcast、softmax 沿哪一维做、混合精度下哪些算子要 cast——这些只有在写代码时才会暴露。亲手写一遍 Self-Attention 之后,你看 vLLM、Flash Attention、PyTorch 内置 nn.MultiheadAttention 时不会再有任何「不知道里面在干嘛」的困惑。
我们会按下面的顺序逐步搭:
- 最朴素版 single-head:把第 2 章的公式直接翻译成 12 行代码
- 升级到 multi-head:把第 3 章的「打包成大矩阵乘」工程实现写出来
- 加 causal mask:让它能用于 Decoder-only 自回归
- 加 RoPE:把第 4 章的旋转位置编码接上
- 封装成 nn.Module:可以直接塞进 Transformer Block
最终代码不到 80 行,但它就是 Llama / GPT 推理时跑的那段 attention——只是没做 Flash Attention 的 SRAM 优化(那是第 18 章)。
8.1 准备:环境与张量约定
先约定本章的张量形状记号:
B= batch size(一次处理几个样本)T= sequence length(序列长度)D=d_model(模型隐层维度,例如 512)H= number of heads(头数)Dh=d_model / H(每个头的维度,例如 64)
代码用 PyTorch 写,假设你已经熟悉 torch.nn.Module、nn.Linear、自动求导这些基础。如果不熟,建议先看一遍《动手学深度学习》第 4-5 章。
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
8.2 第一版:12 行的朴素 Self-Attention
我们先写最贴近第 2 章公式 的版本。这版只有 single head、不带 mask、不带位置编码——纯粹的「数学翻译」。
def naive_attention(x, W_q, W_k, W_v):
"""
x: (B, T, D)
W_*: (D, D) Q/K/V 投影矩阵
return: (B, T, D)
"""
Q = x @ W_q # (B, T, D)
K = x @ W_k # (B, T, D)
V = x @ W_v # (B, T, D)
d_k = Q.size(-1)
scores = Q @ K.transpose(-2, -1) # (B, T, T)
scores = scores / math.sqrt(d_k) # 缩放
attn = F.softmax(scores, dim=-1) # 沿 K 维归一化
out = attn @ V # (B, T, D)
return out
12 行(注释除外)。每一行都对应公式里的一个步骤:
| 公式 | 代码 |
|---|---|
Q = x @ W_q |
|
K = x @ W_k |
|
V = x @ W_v |
|
scores = Q @ K.transpose(-2, -1) |
|
scores = scores / math.sqrt(d_k) |
|
attn = F.softmax(scores, dim=-1) |
|
out = attn @ V |
注意几个 PyTorch 细节:
@是矩阵乘法:等价于torch.matmul。在 batched 张量上自动广播 batch 维。K.transpose(-2, -1)把最后两个维度交换,让K从(B, T, D)变成(B, D, T),方便和Q做Q @ K^T得到(B, T, T)。F.softmax(scores, dim=-1)沿最后一维(即 K 维)归一化。这一行如果dim写错(比如写成dim=-2),整个 attention 就崩了——这是新手最容易踩的坑之一。
跑一下让它工作:
B, T, D = 2, 5, 16
x = torch.randn(B, T, D)
W_q, W_k, W_v = (torch.randn(D, D) for _ in range(3))
out = naive_attention(x, W_q, W_k, W_v)
print(out.shape) # torch.Size([2, 5, 16])
正确输出形状 (2, 5, 16)——和输入一致。这样我们的「最小可工作版」就跑通了。
8.3 第二版:multi-head 的工程实现
现在升级到 Multi-Head。回忆第 3 章 3.3 节,所有头被打包到一次大矩阵乘法里——不是 for 循环跑 H 次,而是用 reshape 把维度从 (B, T, D) 拆成 (B, H, T, Dh):
def multi_head_attention(x, W_q, W_k, W_v, W_o, H):
"""
x: (B, T, D)
W_q/k/v: (D, D) 打包后的 QKV 投影(每个头的投影拼起来)
W_o: (D, D) 输出投影
H: int 头数
"""
B, T, D = x.shape
Dh = D // H
# 一次大矩阵乘 → 拆成多头
Q = (x @ W_q).view(B, T, H, Dh).transpose(1, 2) # (B, H, T, Dh)
K = (x @ W_k).view(B, T, H, Dh).transpose(1, 2) # (B, H, T, Dh)
V = (x @ W_v).view(B, T, H, Dh).transpose(1, 2) # (B, H, T, Dh)
scores = Q @ K.transpose(-2, -1) / math.sqrt(Dh) # (B, H, T, T)
attn = F.softmax(scores, dim=-1) # (B, H, T, T)
out = attn @ V # (B, H, T, Dh)
# 把多头合回去
out = out.transpose(1, 2).contiguous().view(B, T, D) # (B, T, D)
return out @ W_o # (B, T, D)
关键变化在 view + transpose 这两行:
(x @ W_q) # (B, T, D)
.view(B, T, H, Dh) # (B, T, H, Dh) 把 D 拆成 H × Dh
.transpose(1, 2) # (B, H, T, Dh) 把 H 调到前面
为什么要 .transpose(1, 2)?因为后面的 attention 计算 Q @ K.transpose(-2, -1) 是在最后两维上做矩阵乘——我们希望「头维」是一个独立的 batch 维(让 PyTorch 自动并行处理 H 个头),所以要把它放到 batch 那一侧。
注意 attention 计算后的 out 形状是 (B, H, T, Dh),要合回 (B, T, D):
out.transpose(1, 2) # (B, T, H, Dh)
.contiguous() # 物理上重新排布内存
.view(B, T, D) # (B, T, D) H × Dh 重新合成 D
.contiguous() 是必须的——transpose 只是改了 stride 信息(虚拟地交换维度),物理内存还是原样;view 要求张量在内存中是连续的,所以要先 .contiguous()。
最后一行 out @ W_o 是输出投影(第 3.2 节的 ),把多头拼起来后做一次最终的线性变换。
flowchart LR X["x (B,T,D)"] --> WQ["× W_q"] WQ --> Q1["(B,T,D)"] Q1 --> RV["view (B,T,H,Dh)"] RV --> TR["transpose (B,H,T,Dh)"] TR --> Q[Q] Q --> ATT["batched<br/>scaled<br/>dot product<br/>attention"] X --> KV[同样得到 K, V] KV --> ATT ATT --> O1["(B,H,T,Dh)"] O1 --> TR2["transpose + reshape (B,T,D)"] TR2 --> WO["× W_o"] WO --> OUT["输出 (B,T,D)"]
8.4 第三版:加上 causal mask
Decoder-only 模型(GPT、Llama)需要因果掩码——位置 i 的 query 只能看 ≤i 的 key。第 2.10 节我们讲过这件事的实现:在 softmax 之前把上三角 score 设为 −∞。
PyTorch 里有个常用工具 torch.tril:
def causal_mask(T):
"""
返回一个 (T, T) 的下三角矩阵,下三角和对角线为 0,上三角为 -inf
"""
mask = torch.full((T, T), float('-inf'))
mask = torch.triu(mask, diagonal=1)
return mask # 上三角是 -inf,其他是 0
或者更直接:
mask = torch.zeros(T, T)
mask[torch.triu(torch.ones(T, T), diagonal=1).bool()] = float('-inf')
把它加到 score 上:
def causal_attention(x, W_q, W_k, W_v, W_o, H):
B, T, D = x.shape
Dh = D // H
Q = (x @ W_q).view(B, T, H, Dh).transpose(1, 2)
K = (x @ W_k).view(B, T, H, Dh).transpose(1, 2)
V = (x @ W_v).view(B, T, H, Dh).transpose(1, 2)
scores = Q @ K.transpose(-2, -1) / math.sqrt(Dh) # (B, H, T, T)
# —— 关键:加上因果掩码 ——
mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
scores = scores + mask.to(scores.device) # 自动 broadcast 到 (B, H, T, T)
attn = F.softmax(scores, dim=-1)
out = attn @ V
out = out.transpose(1, 2).contiguous().view(B, T, D)
return out @ W_o
加 mask 这一步几个细节:
- broadcast 自动处理:
mask的形状是(T, T),加到(B, H, T, T)的scores上时 PyTorch 自动复制 B 和 H 维度。 - 数值稳定:用
-inf直接加;exp(-inf) = 0,softmax 结果中这些位置就贡献 0。一些工程实现会用-1e9代替-inf,避免某些设备下 NaN 问题。 - device 一致:
mask.to(scores.device)是关键——如果 scores 在 GPU 上而 mask 在 CPU 上,会触发跨设备运行时错误。
验证一下 mask 的效果:
B, T, D, H = 1, 4, 8, 2
x = torch.randn(B, T, D)
W_q, W_k, W_v, W_o = (torch.randn(D, D) for _ in range(4))
# 两个版本的输出
out_full = multi_head_attention(x, W_q, W_k, W_v, W_o, H)
out_causal = causal_attention(x, W_q, W_k, W_v, W_o, H)
# causal 版本的位置 0 应该和「只把序列截断到位置 0 时」算出来的输出一致
x_truncated = x[:, :1, :] # (1, 1, D)
out_truncated = causal_attention(x_truncated, W_q, W_k, W_v, W_o, H)
print(torch.allclose(out_causal[:, :1], out_truncated)) # 应该是 True
这是 causal mask 必备的「自一致性」检验:截断输入后第 0 个位置的输出,应该等于完整输入下第 0 个位置的输出(因为 causal mask 让位置 0 看不到位置 1, 2, 3)。
8.5 第四版:加上 RoPE
我们继续把第 4 章的 RoPE 加进来。RoPE 旋转的对象是 Q 和 K(不是 V),位置在「投影后、attention 计算前」。
def precompute_rope(Dh, max_T, base=10000):
"""
预计算 RoPE 用的 cos / sin 表
Dh: 每个头的维度
max_T: 最大序列长度
return: cos, sin (max_T, Dh)
"""
inv_freq = 1.0 / (base ** (torch.arange(0, Dh, 2).float() / Dh)) # (Dh/2,)
pos = torch.arange(max_T).float() # (max_T,)
freqs = pos.unsqueeze(-1) * inv_freq.unsqueeze(0) # (max_T, Dh/2)
# 把每个频率重复一次(对 (real, imag) 两个分量都用同一频率)
cos = freqs.cos().repeat_interleave(2, dim=-1) # (max_T, Dh)
sin = freqs.sin().repeat_interleave(2, dim=-1) # (max_T, Dh)
return cos, sin
def apply_rope(x, cos, sin):
"""
x: (B, H, T, Dh)
cos: (T, Dh)
sin: (T, Dh)
return: 旋转后的 x
"""
# 把 x 拆成偶数维和奇数维交错对
x1, x2 = x[..., 0::2], x[..., 1::2] # (B, H, T, Dh/2)
x_rot = torch.stack([-x2, x1], dim=-1).flatten(-2) # (B, H, T, Dh)
# cos / sin 的 broadcast
cos = cos[None, None, :, :] # (1, 1, T, Dh)
sin = sin[None, None, :, :]
return x * cos + x_rot * sin
precompute_rope 是一次性的预计算——同一个模型只算一次,cache 起来反复用。apply_rope 是每次前向都要做的旋转操作。
把 RoPE 接到 attention 里:
def causal_attention_with_rope(x, W_q, W_k, W_v, W_o, H, cos, sin):
B, T, D = x.shape
Dh = D // H
Q = (x @ W_q).view(B, T, H, Dh).transpose(1, 2) # (B, H, T, Dh)
K = (x @ W_k).view(B, T, H, Dh).transpose(1, 2)
V = (x @ W_v).view(B, T, H, Dh).transpose(1, 2)
# —— 新增:对 Q、K 应用 RoPE,V 不旋转 ——
Q = apply_rope(Q, cos[:T], sin[:T])
K = apply_rope(K, cos[:T], sin[:T])
scores = Q @ K.transpose(-2, -1) / math.sqrt(Dh)
mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
scores = scores + mask.to(scores.device)
attn = F.softmax(scores, dim=-1)
out = attn @ V
out = out.transpose(1, 2).contiguous().view(B, T, D)
return out @ W_o
注意 cos[:T] 和 sin[:T] 是 slice——只取当前序列长度需要的部分(max_T 是预计算的最大值,实际序列可能更短)。
8.6 第五版:封装成 nn.Module
最后我们把所有零件封装成一个 nn.Module,方便塞进完整的 Transformer:
class CausalSelfAttention(nn.Module):
"""
Decoder-only Transformer 的核心 attention 模块。
包含: Multi-Head + RoPE + Causal Mask
"""
def __init__(self, d_model, n_heads, max_seq_len=8192, rope_base=10000):
super().__init__()
assert d_model % n_heads == 0, "d_model 必须能被 n_heads 整除"
self.d_model = d_model
self.n_heads = n_heads
self.d_head = d_model // n_heads
# QKV 三个投影合并成一次大 Linear(fused),减少一次 launch 开销
self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
self.out_proj = nn.Linear(d_model, d_model, bias=False)
# RoPE 预计算(注册为 buffer,不参与训练但跟随 .to(device))
cos, sin = self._precompute_rope(self.d_head, max_seq_len, rope_base)
self.register_buffer("rope_cos", cos, persistent=False)
self.register_buffer("rope_sin", sin, persistent=False)
@staticmethod
def _precompute_rope(d_head, max_t, base):
inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
pos = torch.arange(max_t).float()
freqs = pos.unsqueeze(-1) * inv_freq.unsqueeze(0)
cos = freqs.cos().repeat_interleave(2, dim=-1)
sin = freqs.sin().repeat_interleave(2, dim=-1)
return cos, sin
@staticmethod
def _apply_rope(x, cos, sin):
x1, x2 = x[..., 0::2], x[..., 1::2]
x_rot = torch.stack([-x2, x1], dim=-1).flatten(-2)
return x * cos[None, None] + x_rot * sin[None, None]
def forward(self, x):
B, T, D = x.shape
H, Dh = self.n_heads, self.d_head
# 一次 Linear 算出 QKV
qkv = self.qkv_proj(x) # (B, T, 3D)
Q, K, V = qkv.split(D, dim=-1) # 每个 (B, T, D)
Q = Q.view(B, T, H, Dh).transpose(1, 2) # (B, H, T, Dh)
K = K.view(B, T, H, Dh).transpose(1, 2)
V = V.view(B, T, H, Dh).transpose(1, 2)
# RoPE
Q = self._apply_rope(Q, self.rope_cos[:T], self.rope_sin[:T])
K = self._apply_rope(K, self.rope_cos[:T], self.rope_sin[:T])
# Scaled dot-product
scores = Q @ K.transpose(-2, -1) / math.sqrt(Dh) # (B, H, T, T)
# Causal mask
mask = torch.triu(torch.full((T, T), float('-inf'), device=x.device), diagonal=1)
scores = scores + mask
attn = F.softmax(scores, dim=-1)
out = attn @ V # (B, H, T, Dh)
# Concat heads + output projection
out = out.transpose(1, 2).contiguous().view(B, T, D)
return self.out_proj(out)
总共 50 行左右(含注释)。这就是一个 production-grade 的 Causal Self-Attention with RoPE。
测试一下它能跑:
torch.manual_seed(42)
B, T, D, H = 2, 16, 128, 4
attn = CausalSelfAttention(d_model=D, n_heads=H, max_seq_len=64)
x = torch.randn(B, T, D)
out = attn(x)
print(out.shape) # torch.Size([2, 16, 128])
print(out.requires_grad) # True,可以反向传播
8.7 几个进阶工程考虑
Fused QKV 投影
我们用 nn.Linear(d_model, 3 * d_model) 把 Q/K/V 三个投影合并成一次。这是一个常见优化:
- 三次 GEMM 合并成一次:减少 GPU launch 次数
- 共享激活值:内存访问模式更连续
GPT-2、Llama 等主流模型的实现都用这种 fused 形式。但要注意:用 GQA / MQA 时,K 和 V 的输出维度小于 Q(因为它们的头数少),不能简单地用 3 * d_model——要分开处理:
# GQA 情况
self.q_proj = nn.Linear(d_model, n_heads * d_head, bias=False)
self.k_proj = nn.Linear(d_model, n_kv_heads * d_head, bias=False)
self.v_proj = nn.Linear(d_model, n_kv_heads * d_head, bias=False)
混合精度训练
实际工程几乎一定用混合精度(FP16 / BF16)训练:
attn = CausalSelfAttention(D, H).to(torch.bfloat16).cuda()
x = torch.randn(B, T, D, dtype=torch.bfloat16, device='cuda')
out = attn(x) # 全部在 BF16 下计算
但有一个数值稳定细节:softmax 一定要在 FP32 下做。因为 softmax 涉及 exp(),BF16 的精度可能让 exp 结果溢出或下溢。PyTorch 的 F.softmax 默认会自动 upcast 到 FP32 计算再 downcast,但如果你手动写,要注意:
scores_fp32 = scores.float() # cast 到 FP32
attn = F.softmax(scores_fp32, dim=-1)
attn = attn.to(scores.dtype) # cast 回原 dtype
用 PyTorch 内置的 scaled_dot_product_attention
PyTorch 2.0+ 提供了 F.scaled_dot_product_attention(SDPA)——它内置了 Flash Attention 等优化:
out = F.scaled_dot_product_attention(
Q, K, V,
attn_mask=None,
is_causal=True, # 自动应用 causal mask,不需要手动构造
dropout_p=0.0,
)
实际工程里,写完原理性的代码后,应该在生产时切换到 SDPA——它能自动选择 Flash Attention / Memory-Efficient Attention 等更快的实现。我们在第 18 章会讲 Flash Attention 内部到底怎么做的。
8.8 调试技巧
写 attention 时容易踩的几个坑,调试技巧如下:
坑 1:softmax 沿错维度
F.softmax(scores, dim=-1) 的 dim 必须是 K 维(最后一维)。如果 attention 矩阵形状是 (B, H, T_q, T_k),softmax 应该归一化 T_k 维。写错成 dim=-2 就完蛋。
调试方法:检查 attention 矩阵每一行的和是否为 1:
print(attn.sum(dim=-1)) # 应该全是 1.0
坑 2:mask 形状错位
mask 应该是 (T, T) 或可以 broadcast 到 (B, H, T, T) 的形状。如果你不小心写成了 (B, T) 这种形状(混淆了「padding mask」和「causal mask」),broadcast 会出意料的结果。
调试方法:手算一个小例子(比如 T=4),打印 mask 后的 score 矩阵,肉眼检查上三角是不是 -inf。
坑 3:view 和 transpose 顺序错
(B, T, D) → (B, T, H, Dh) 用 .view 是对的;(B, T, H, Dh) → (B, H, T, Dh) 必须用 .transpose,不能用 .view。原因:view 要求维度的展开顺序和内存连续顺序一致——你不能跳着重排维度。
调试方法:每次 reshape 后打印 stride:
print(x.stride())
print(x.is_contiguous())
坑 4:忘了 .contiguous()
transpose 之后 view 之前一定要 .contiguous(),否则 PyTorch 会报:
RuntimeError: view size is not compatible with input tensor's size and stride
坑 5:device / dtype 不一致
mask、cos、sin 这些预计算的张量要和 input 在同一个 device、同一个 dtype。register_buffer 是把它们「绑定」到模块上的标准做法——.to(device) 调用时会自动迁移。
8.9 用 unittest 做自一致性检验
最后给一个完整的测试脚本,验证我们的实现没问题:
def test_causal_self_attention():
torch.manual_seed(42)
B, T, D, H = 2, 16, 128, 4
attn = CausalSelfAttention(d_model=D, n_heads=H, max_seq_len=64)
x = torch.randn(B, T, D)
# 1. 形状测试
out = attn(x)
assert out.shape == (B, T, D), f"shape mismatch: {out.shape}"
# 2. 因果性测试: 截断输入后位置 0 输出应不变
out_full = attn(x)
out_truncated = attn(x[:, :8])
assert torch.allclose(out_full[:, :8], out_truncated, atol=1e-5)
# 3. 反向传播测试: 可以 backward 到 input
loss = out.sum()
loss.backward()
print("✓ 所有测试通过")
test_causal_self_attention()
这个测试覆盖了:
- 形状正确:输入输出形状一致
- 因果性正确:截断不变性(causal mask 起作用了)
- 反向传播畅通:所有 op 都可微
实际写一个新 attention 实现时,先写测试再写代码——这样能在 100 行代码里少走一周弯路。
8.10 这版代码与 Llama / GPT 的差别
读到这里你可能会问:这 80 行代码和真正的 Llama / GPT 的 attention 代码差多少?
差别其实不大。Llama 2 / Llama 3 的 attention 实现核心和我们的版本几乎一样,主要差异在:
- GQA(Grouped-Query Attention):K 和 V 的头数是 Q 的 1/8。在我们的代码里要把 Q 投影维度(
H * Dh)和 K/V 投影维度(H_kv * Dh)分开。 - Flash Attention:用
F.scaled_dot_product_attention或 xformers 的优化版替换我们手写的「scores → softmax → attn @ V」三步——内存效率高很多(第 18 章详解)。 - KV Cache:推理时缓存 K、V 而不是每次重算(第 15 章详解)。
- MLA(Multi-Head Latent Attention,DeepSeek-V2/V3):把 KV 压缩到一个低维 latent 空间——和 GQA 思路类似,但更激进。
但所有这些扩展都是在我们这版代码的基础上加东西,不会推翻这个核心。所以你掌握了第 8 章这版代码,就掌握了今天主流大模型 attention 实现的 90%。
剩下的 10%(KV Cache、Flash Attention)是第六部分(推理系统)的内容。
本章小结
- 从纸面到代码只需 12 行——naive attention 基本就是公式的逐行翻译。
- 多头的工程实现是 reshape 而不是 for 循环——所有头被打包到一次 batched matmul 里跑。
- causal mask 通过加上一个上三角
-inf实现——softmax 后这些位置贡献 0,相当于「看不见未来」。 - RoPE 旋转的是 Q 和 K,不是 V——预计算 cos/sin 表,每次前向 broadcast 应用。
- 封装成 nn.Module 后这是一个 production-grade 的实现——用 buffer 注册预计算张量、用 fused QKV 投影减少 launch 开销。
- 几个常见坑:softmax 沿错维度、mask 形状错位、view/transpose 顺序错、忘 contiguous、device/dtype 不一致。
- PyTorch 2.0+ 的 SDPA 内置 Flash Attention,生产环境推荐用它替换手写实现。
- 这版代码和真正的 Llama / GPT 几乎一致,差别在 GQA、Flash Attention、KV Cache 这些工程优化。
下一章我们用这套 attention 搭一个完整的 mini-GPT,加上 embedding、FFN、LayerNorm、训练循环——目标是从零训练一个能写古诗的小模型。
延伸阅读
- Andrej Karpathy, nanoGPT GitHub 仓库——本章 8.6 节的代码风格直接受其启发,强烈建议把整个仓库读一遍。
- PyTorch 文档:
torch.nn.functional.scaled_dot_product_attention——Flash Attention 的内置版本。 - xformers:
xformers.ops.memory_efficient_attention——更激进的内存优化版本。 - Llama 官方代码:
meta-llama/llama仓库的model.py——production 实现参考。 - 第 18 章本书后续:Flash Attention 内部到底怎么做的。