MiniViLT:简化ViLT实现
本文记录了我在 MiniViLT 项目 中,从零开始实现轻量级 ViLT(Vision-and-Language Transformer)的过程。 项目目标是:在 8GB 显存环境下,独立实现并训练一个小型多模态模型,支持图文匹配任务(ITM)。
1. 项目动机
原始 ViLT(ICML 2021)展示了一个非常简洁的思想:
不使用卷积特征提取器,也不依赖目标检测器,直接将图像划分为 patch,并与文本 token 拼接后输入 Transformer。
为了理解其核心机制,我决定自己动手实现一个简化版——MiniViLT,既保留核心结构,又满足资源受限环境。
2. 工程结构
整个项目采用模块化设计,目录结构如下:
1 | MiniViLT/ |
这种组织方式便于后续扩展(如添加 MLM、VQA 等任务)。
3. 数据准备
我选择 Flickr8k 数据集,它小而经典,适合显存有限的实验。
数据处理主要包括两步:
- 将
captions.txt
转换为 JSON 格式,每张图像对应 5 条描述; - 遍历所有 caption,统计词频并构建
vocab.json
。
示例代码(build_vocab.py
):
1 | counter = Counter() |
4. 模型实现
4.1 Patch Embedding
将图像切成 patch 并投影到向量空间:
1 | self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size) |
4.2 Text Embedding
文本部分参考 BERT embedding:
1 | token_embed = nn.Embedding(vocab_size, embed_dim) |
4.3 多模态融合
拼接 [CLS] + image_patches + text_tokens
,然后送入 Transformer:
1 | fused = torch.cat([cls_token, image_embed, text_embed], dim=1) |
4.4 Transformer 编码器
Multi-Head Self-Attention + FFN + LayerNorm:
1 | scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k) |
5. 训练流程
任务:Image-Text Matching (ITM)
- 正样本:真实图文对
- 负样本:随机打乱 caption
1 | loss = nn.CrossEntropyLoss()(logits, labels) |
保存训练日志:
1 | with open("logs/training_log.csv", "a") as f: |
6. 训练效果
在 8GB 显存 + Flickr8k 上训练 3 epoch:
1 | epoch,loss,acc |
初期精度接近随机水平,但模型结构正确运行,后续可通过增大模型容量、调整学习率、增强负样本策略来提升性能。