本篇博客作为我的个人学习记录,详细梳理了标准实现与优化实现中的权重矩阵结构、维度变化过程,并结合我自己实现的 MultiHeadAttention
类深入分析了 PyTorch 中一些关键函数的作用,包括 view()
、reshape()
、contiguous()
等在实现过程中的实际含义与区别。
一、标准多头注意力的权重矩阵结构
1. 多组权重矩阵(直观方式)
每个注意力头各自拥有独立的查询(Wq)、键(Wk)和值(Wv)投影矩阵: 这种做法直观但效率低,不利于并行运算和矩阵加速。
2. 输出拼接与再投影
多个注意力头的输出拼接后通过一个线性变换矩阵:
二、优化实现:共享大矩阵 + reshape 拆头
1. 权重矩阵合并策略
将所有头的 Wq 合并成一个大矩阵:
1
| self.q_proj = nn.Linear(d_model, d_model)
|
类似地,k_proj
和 v_proj
也是如此。这样一次前向就可以得到所有头的投影结果,再用 .view()
和 .transpose()
拆解为多个头。
2. reshape 的拆分步骤
1
| Q = self.q_proj(x).view(batch, seq_len, n_heads, d_k).transpose(1, 2)
|
这一步完成从 [B, L, d_model]
到 [B, n_heads, L, d_k]
的重排。
三、维度变化全过程(结合代码逐步分析)
输入
1
| x: [batch, seq_len, d_model]
|
投影 + reshape + transpose
1 2
| Q = self.q_proj(x) Q = Q.view(B, L, h, d_k).transpose(1, 2)
|
view()
:只改变张量的形状,不改变内存;
transpose(1, 2)
:交换维度,用于满足后续矩阵乘法的维度要求;
- 注意:使用
view()
前通常要 .contiguous()
,确保张量在内存中是连续的;否则可能出错。
注意力得分计算
1 2
| scores = Q @ K.transpose(-2, -1) / sqrt(d_k)
|
Mask、softmax、dropout(标准注意力机制流程)
注意力输出与合并
1 2
| out = attn @ V out = out.transpose(1, 2).contiguous().view(B, L, d_model)
|
contiguous()
:把转置后的张量重新存储为连续内存;
view(B, L, d_model)
:恢复拼接后的整体表示。
输出线性变换
1
| return self.out_proj(out)
|
四、为何采用共享投影矩阵实现?
相比传统方式,合并矩阵 + reshape 的方式有诸多优势:
并行计算效率 |
✅ 更高 |
❌ 多次小矩阵计算 |
参数共享性 |
✅ 模块统一 |
❌ 不便扩展 |
内存结构优化 |
✅ 使用连续大矩阵 |
❌ 零散小块 |
五、实用函数总结
view()
:高效改变张量形状,但要求数据是连续的;
reshape()
:比 view()
更灵活,能自动处理非连续张量,但略慢;
contiguous()
:生成内存连续的副本,常与 .view()
连用;
transpose()
:交换维度,但不改变内存顺序;
可行的实践顺序:
1
| x.transpose(...).contiguous().view(...)
|