Vision Transformer简化实现与训练记录
本项目为我在学习 Vision Transformer(ViT)架构时的工程实践记录。目标是从零实现一个简化版的 ViT 模型,并在 CIFAR-10 上进行训练与推理。完整代码已上传至 GitHub:
🔗 https://github.com/QianQing26/MiniViT
1. 项目目的与背景
传统图像分类多基于 CNN,而 ViT 作为纯 Transformer 架构,移除了卷积模块,仅靠 attention 建模全局关系。为了深入理解 ViT 的机制,我尝试手动实现一个简化模型,并完成训练、评估、推理与可视化流程。
2. 模型结构设计(MiniViT)
模型遵循论文《An Image is Worth 16x16 Words》中的主干结构:
- 输入图像划分为 patch,展平后进行线性映射;
- 加入
[CLS]
token 与可学习位置编码; - 多层标准 Transformer Encoder(包含 MHSA + MLP);
- 使用
[CLS]
token 的输出进行分类。
模型参数配置如下:
1 | model = MiniViT( |
3. 项目结构与模块划分
1 | MiniViT/ |
4. 训练与验证记录
训练配置
- 数据集:CIFAR-10
- 优化器:AdamW(lr=3e-4)
- batch_size:128
- epoch 数:50
- patch_size:4 × 4
- embed_dim:256
最佳验证准确率
1 | Epoch 50: Val Acc = 71.07% |
训练过程中,训练 loss 与 acc 持续平稳提升,验证集表现稳定,整体无过拟合倾向。
5. 推理可视化
推理脚本 inference.py
支持随机抽样若干张测试图像,输出模型预测类别与真实标签。
- 使用训练完成的
best_model.pt
权重 - 支持 matplotlib 保存结果图像
inference_result.png
- 错误预测以红色标题显示,正确为绿色
6. 技术细节与经验记录
- 使用
tqdm
显示训练进度; - 多线程加载数据时需注意 Windows 平台下的
if __name__ == "__main__"
; - 解决
libiomp5md.dll
冲突:需在开头加入os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
; - 使用
matplotlib
和pandas
保存与绘制训练曲线; - 训练中每轮记录训练 / 验证 loss 和 acc,并保存最优模型;
- 推理时使用 softmax 输出置信度用于可视化结果更直观。
7. 项目地址
📦 GitHub 项目仓库:https://github.com/QianQing26/MiniViT