Vision Transformer简化实现与训练记录

本项目为我在学习 Vision Transformer(ViT)架构时的工程实践记录。目标是从零实现一个简化版的 ViT 模型,并在 CIFAR-10 上进行训练与推理。完整代码已上传至 GitHub:
🔗 https://github.com/QianQing26/MiniViT

1. 项目目的与背景

传统图像分类多基于 CNN,而 ViT 作为纯 Transformer 架构,移除了卷积模块,仅靠 attention 建模全局关系。为了深入理解 ViT 的机制,我尝试手动实现一个简化模型,并完成训练、评估、推理与可视化流程。


2. 模型结构设计(MiniViT)

模型遵循论文《An Image is Worth 16x16 Words》中的主干结构:

  • 输入图像划分为 patch,展平后进行线性映射;
  • 加入 [CLS] token 与可学习位置编码;
  • 多层标准 Transformer Encoder(包含 MHSA + MLP);
  • 使用 [CLS] token 的输出进行分类。

模型参数配置如下:

1
2
3
4
5
6
7
8
model = MiniViT(
img_size=32,
patch_size=4,
embed_dim=256,
depth=8,
heads=8,
num_classes=10
)

3. 项目结构与模块划分

1
2
3
4
5
6
7
8
9
10
11
12
MiniViT/
├── models/
│ ├── patch_embed.py # Patch Embedding 实现
│ ├── transformer.py # Attention、FFN、Transformer Block
│ └── vit.py # 主模型结构拼接
├── train.py # 包含训练 + 验证 + best model 保存
├── inference.py # 推理脚本,支持图像与预测可视化
├── utils.py # 工具函数:绘图、加载模型等
├── training_log.csv # 每轮训练记录日志
├── loss.png / acc.png # 训练曲线图
├── inference_result.png # 推理可视化结果图
└── README.md # 仓库介绍

4. 训练与验证记录

训练配置

  • 数据集:CIFAR-10
  • 优化器:AdamW(lr=3e-4)
  • batch_size:128
  • epoch 数:50
  • patch_size:4 × 4
  • embed_dim:256

最佳验证准确率

1
Epoch 50: Val Acc = 71.07%

训练过程中,训练 loss 与 acc 持续平稳提升,验证集表现稳定,整体无过拟合倾向。


5. 推理可视化

推理脚本 inference.py 支持随机抽样若干张测试图像,输出模型预测类别与真实标签。

  • 使用训练完成的 best_model.pt 权重
  • 支持 matplotlib 保存结果图像 inference_result.png
  • 错误预测以红色标题显示,正确为绿色

6. 技术细节与经验记录

  • 使用 tqdm 显示训练进度;
  • 多线程加载数据时需注意 Windows 平台下的 if __name__ == "__main__"
  • 解决 libiomp5md.dll 冲突:需在开头加入 os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
  • 使用 matplotlibpandas 保存与绘制训练曲线;
  • 训练中每轮记录训练 / 验证 loss 和 acc,并保存最优模型;
  • 推理时使用 softmax 输出置信度用于可视化结果更直观。

7. 项目地址

📦 GitHub 项目仓库:https://github.com/QianQing26/MiniViT