从实现看理解:多头注意力机制

"We call our model the Transformer. The model architecture is shown in Figure 1."
—— Attention is All You Need, Vaswani et al., 2017

在 Transformer 中,多头注意力机制(Multi-Head Attention, MHA)是核心组件之一。论文中虽然只用了短短几页进行描述,但其实现中蕴含着大量工程智慧与数学原则。本文将结合我自己复现的transformer和 GPT 模型,站在一个研零初学者的视角上,从 mask 使用softmax 数值稳定性 角度,学习《Attention is All You Need》中的一些小细节,并结合实际代码加以分析。

一段自己的实现

先来看一段我的实现代码:

1
2
3
4
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
scores = scores.masked_fill(~mask, float("-inf"))
attn = torch.softmax(scores, dim=-1)
attn = self.dropout(attn)

这几行代码常见于多头注意力实现,其背后隐含的思路与原始论文的公式(如下)完全一致:

cite:论文公式 (1)


一、三种 Mask 的语义角色

论文中在描述 Decoder 时提到:

“To prevent leftward information flow in the decoder to preserve the auto-regressive property, we apply a mask to the input...”

这说明在 Decoder 的 Self-Attention 中,mask 是结构必要条件。在工程中,我们通常使用三种 mask:

  1. Padding mask:忽略 <PAD> 的影响;
  2. Look-ahead mask:防止看到未来的 token;
  3. 组合 mask:前两种的结合。

示例代码:

1
scores = scores.masked_fill(~mask, float("-inf"))

这相当于把不应被注意的位置设为 -∞,使其 softmax 权重为 0。

注意:mask 的 shape 通常应为 [batch, 1, 1, seq_len][batch, head, seq_len, seq_len],否则将 silently fail!


二、softmax 前的除根号 dₖ 的必要性

“This scaling is necessary to counteract the effect of dot products growing large in magnitude...”

随着维度增长,dot-product 越来越大,softmax 输出趋于极端(接近 one-hot),梯度消失。因此我们除以 ,使得 softmax 有效范围保持稳定。

1
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)

亲测若省略这一步,模型学习将大幅恶化,训练过程可能出现 loss 为 NaN。


三、masked_fill 与softmax数值稳定性

从数值计算角度讲:

1
attn = torch.softmax(scores, dim=-1)

其中 scores 已被设为 -inf 的位置,在 softmax 中自动转化为 0。若设为较小负值(如 -1e9),虽然理论上也趋近 0,但会留下数值误差,甚至导致梯度传播问题。

因此,最好用 float('-inf'),PyTorch 内部对此进行了优化。


四、模块化设计中的注意力抽象

通过模块化设计,我们可以将注意力机制抽象为如下计算流程:

1
Q, K, V -> attention score -> mask -> softmax -> dropout -> context vector

这样的设计与论文逻辑对齐,也便于复用不同注意力策略(如 Sparse Attention、Causal Attention、Relative Position Attention 等)。


五、多头注意力的维度设计

“Instead of performing a single attention function with dmodel-dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values h times with different, learned linear projections...”


为什么要多头?

Transformer 并不使用一个大头(单个全维度 attention),而是将 d_model 拆分成多个小头,每个头可以关注输入的不同部分。这一设计背后有三个动因:

  1. 增强表达能力:每个注意力头可以学习不同的子空间;
  2. 并行高效计算:所有头共享相同的输入,但使用不同投影;
  3. 更少计算开销:小头维度 d_k = d_model / n_heads 降低了 softmax 操作的复杂度。

维度设计示意

项目 维度(假设) 说明
输入 x [B, T, d_model] Batch 大小为 B,序列长度为 T
Q/K/V 投影前 [B, T, d_model] Embedding or上一层输出
Q/K/V 投影后 [B, h, T, d_k] 多头展开后,每个头维度为 d_k
Attention输出 [B, h, T, d_k] 每头输出维度一致
合并后输出 [B, T, d_model] 拼接所有头,再做一次线性变换

其中:

  • ,如论文中设定 d_model = 512h = 8,则每头 d_k = 64

代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 输入 [B, T, d_model]
Q = self.q_proj(query) # -> [B, T, d_model]
K = self.k_proj(key) # -> [B, T, d_model]
V = self.v_proj(value) # -> [B, T, d_model]

# 拆成多头,变为 [B, h, T, d_k]
Q = Q.view(B, T, h, d_k).transpose(1, 2)
K = K.view(B, T, h, d_k).transpose(1, 2)
V = V.view(B, T, h, d_k).transpose(1, 2)

# 注意力计算与拼接输出
scores = (Q @ K.transpose(-2, -1)) / sqrt(d_k) # [B, h, T, T]
attn = softmax(scores)
output = attn @ V # [B, h, T, d_k]
output = output.transpose(1, 2).reshape(B, T, d_model)

优点总结:

  • 高效并行性:多头 attention 可以一次性计算所有 heads;
  • 低内存占用:相比直接做 [B, T, T] 的高维注意力,这种头部拆分方式更节省;
  • 性能更强:论文实验证明:多个小头的组合 > 一个大头;
  • 语义多样性:不同头可捕捉不同的语法/语义关系,如某些关注主语,某些关注动词。

小总结

Transformer 的强大源于注意力机制的全局建模能力,但真正掌握它的工程实现,必须深入每一处看似微不足道的细节——如 mask 的逻辑维度、数值稳定性、softmax 、维度设计的配合。

这也是我将复现模型作为学习的一部分的初衷:只有亲手书写,才能真正理解为何如此实现、为何不能省略。


参考文献:

  • Vaswani et al., “Attention is All You Need”, 2017
  • https://arxiv.org/abs/1706.03762