多头注意力计算中的矩阵维度分析

本篇博客作为我的个人学习记录,详细梳理了标准实现与优化实现中的权重矩阵结构、维度变化过程,并结合我自己实现的 MultiHeadAttention 类深入分析了 PyTorch 中一些关键函数的作用,包括 view()reshape()contiguous() 等在实现过程中的实际含义与区别。


一、标准多头注意力的权重矩阵结构

1. 多组权重矩阵(直观方式)

每个注意力头各自拥有独立的查询(Wq)、键(Wk)和值(Wv)投影矩阵: 这种做法直观但效率低,不利于并行运算和矩阵加速。

2. 输出拼接与再投影

多个注意力头的输出拼接后通过一个线性变换矩阵:


二、优化实现:共享大矩阵 + reshape 拆头

1. 权重矩阵合并策略

将所有头的 Wq 合并成一个大矩阵:

1
self.q_proj = nn.Linear(d_model, d_model)  # 实际 shape: [d_model, h * d_k]

类似地,k_projv_proj 也是如此。这样一次前向就可以得到所有头的投影结果,再用 .view().transpose() 拆解为多个头。

2. reshape 的拆分步骤

1
Q = self.q_proj(x).view(batch, seq_len, n_heads, d_k).transpose(1, 2)

这一步完成从 [B, L, d_model][B, n_heads, L, d_k] 的重排。


三、维度变化全过程(结合代码逐步分析)

输入

1
x: [batch, seq_len, d_model]

投影 + reshape + transpose

1
2
Q = self.q_proj(x)  # [B, L, d_model]
Q = Q.view(B, L, h, d_k).transpose(1, 2) # [B, h, L, d_k]
  • view():只改变张量的形状,不改变内存;
  • transpose(1, 2):交换维度,用于满足后续矩阵乘法的维度要求;
  • 注意:使用 view() 前通常要 .contiguous(),确保张量在内存中是连续的;否则可能出错。

注意力得分计算

1
2
scores = Q @ K.transpose(-2, -1) / sqrt(d_k)
# => [B, h, L, L]

Mask、softmax、dropout(标准注意力机制流程)

注意力输出与合并

1
2
out = attn @ V  # [B, h, L, d_k]
out = out.transpose(1, 2).contiguous().view(B, L, d_model)
  • contiguous():把转置后的张量重新存储为连续内存;
  • view(B, L, d_model):恢复拼接后的整体表示。

输出线性变换

1
return self.out_proj(out)  # [B, L, d_model]

四、为何采用共享投影矩阵实现?

相比传统方式,合并矩阵 + reshape 的方式有诸多优势:

特性 优化实现 多组权重实现
并行计算效率 ✅ 更高 ❌ 多次小矩阵计算
参数共享性 ✅ 模块统一 ❌ 不便扩展
内存结构优化 ✅ 使用连续大矩阵 ❌ 零散小块

五、实用函数总结

  • view():高效改变张量形状,但要求数据是连续的;
  • reshape():比 view() 更灵活,能自动处理非连续张量,但略慢;
  • contiguous():生成内存连续的副本,常与 .view() 连用;
  • transpose():交换维度,但不改变内存顺序;

可行的实践顺序:

1
x.transpose(...).contiguous().view(...)