从实现看理解:多头注意力机制
"We call our model the Transformer. The model architecture is shown in Figure 1."
—— Attention is All You Need, Vaswani et al., 2017
在 Transformer 中,多头注意力机制(Multi-Head Attention, MHA)是核心组件之一。论文中虽然只用了短短几页进行描述,但其实现中蕴含着大量工程智慧与数学原则。本文将结合我自己复现的transformer和 GPT 模型,站在一个研零初学者的视角上,从 mask 使用、softmax 数值稳定性 角度,学习《Attention is All You Need》中的一些小细节,并结合实际代码加以分析。
一段自己的实现
先来看一段我的实现代码:
1 | scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k) |
这几行代码常见于多头注意力实现,其背后隐含的思路与原始论文的公式(如下)完全一致:
cite:论文公式 (1)
一、三种 Mask 的语义角色
论文中在描述 Decoder 时提到:
“To prevent leftward information flow in the decoder to preserve the auto-regressive property, we apply a mask to the input...”
这说明在 Decoder 的 Self-Attention 中,mask 是结构必要条件。在工程中,我们通常使用三种 mask:
- Padding mask:忽略
<PAD>的影响; - Look-ahead mask:防止看到未来的 token;
- 组合 mask:前两种的结合。
示例代码:
1 | scores = scores.masked_fill(~mask, float("-inf")) |
这相当于把不应被注意的位置设为 -∞,使其 softmax 权重为 0。
注意:mask 的 shape 通常应为 [batch, 1, 1, seq_len] 或 [batch, head, seq_len, seq_len],否则将 silently fail!
二、softmax 前的除根号 dₖ 的必要性
“This scaling is necessary to counteract the effect of dot products growing large in magnitude...”
随着维度增长,dot-product 越来越大,softmax 输出趋于极端(接近 one-hot),梯度消失。因此我们除以
1 | scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k) |
亲测若省略这一步,模型学习将大幅恶化,训练过程可能出现 loss 为 NaN。
三、masked_fill 与softmax数值稳定性
从数值计算角度讲:
1 | attn = torch.softmax(scores, dim=-1) |
其中 scores 已被设为 -inf 的位置,在 softmax 中自动转化为 0。若设为较小负值(如 -1e9),虽然理论上也趋近 0,但会留下数值误差,甚至导致梯度传播问题。
因此,最好用 float('-inf'),PyTorch 内部对此进行了优化。
四、模块化设计中的注意力抽象
通过模块化设计,我们可以将注意力机制抽象为如下计算流程:
1 | Q, K, V -> attention score -> mask -> softmax -> dropout -> context vector |
这样的设计与论文逻辑对齐,也便于复用不同注意力策略(如 Sparse Attention、Causal Attention、Relative Position Attention 等)。
五、多头注意力的维度设计
“Instead of performing a single attention function with dmodel-dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values h times with different, learned linear projections...”
为什么要多头?
Transformer 并不使用一个大头(单个全维度 attention),而是将 d_model 拆分成多个小头,每个头可以关注输入的不同部分。这一设计背后有三个动因:
- 增强表达能力:每个注意力头可以学习不同的子空间;
- 并行高效计算:所有头共享相同的输入,但使用不同投影;
- 更少计算开销:小头维度
d_k = d_model / n_heads降低了 softmax 操作的复杂度。
维度设计示意
| 项目 | 维度(假设) | 说明 |
|---|---|---|
输入 x |
[B, T, d_model] |
Batch 大小为 B,序列长度为 T |
| Q/K/V 投影前 | [B, T, d_model] |
Embedding or上一层输出 |
| Q/K/V 投影后 | [B, h, T, d_k] |
多头展开后,每个头维度为 d_k |
| Attention输出 | [B, h, T, d_k] |
每头输出维度一致 |
| 合并后输出 | [B, T, d_model] |
拼接所有头,再做一次线性变换 |
其中:
,如论文中设定 d_model = 512,h = 8,则每头d_k = 64
代码实现:
1 | # 输入 [B, T, d_model] |
优点总结:
- 高效并行性:多头 attention 可以一次性计算所有 heads;
- 低内存占用:相比直接做
[B, T, T]的高维注意力,这种头部拆分方式更节省; - 性能更强:论文实验证明:多个小头的组合 > 一个大头;
- 语义多样性:不同头可捕捉不同的语法/语义关系,如某些关注主语,某些关注动词。
小总结
Transformer 的强大源于注意力机制的全局建模能力,但真正掌握它的工程实现,必须深入每一处看似微不足道的细节——如 mask 的逻辑维度、数值稳定性、softmax 、维度设计的配合。
这也是我将复现模型作为学习的一部分的初衷:只有亲手书写,才能真正理解为何如此实现、为何不能省略。
参考文献:
- Vaswani et al., “Attention is All You Need”, 2017
- https://arxiv.org/abs/1706.03762