MiniViLT:简化ViLT实现

本文记录了我在 MiniViLT 项目 中,从零开始实现轻量级 ViLT(Vision-and-Language Transformer)的过程。 项目目标是:在 8GB 显存环境下,独立实现并训练一个小型多模态模型,支持图文匹配任务(ITM)。


1. 项目动机

原始 ViLT(ICML 2021)展示了一个非常简洁的思想:

不使用卷积特征提取器,也不依赖目标检测器,直接将图像划分为 patch,并与文本 token 拼接后输入 Transformer。

为了理解其核心机制,我决定自己动手实现一个简化版——MiniViLT,既保留核心结构,又满足资源受限环境。


2. 工程结构

整个项目采用模块化设计,目录结构如下:

1
2
3
4
5
6
7
8
MiniViLT/
├── data/ # 数据处理 & 数据集
├── models/ # 模型模块
├── utils/ # 工具函数
├── checkpoints/ # 训练保存
├── logs/ # 训练日志
├── train.py # 训练入口
└── README.md

这种组织方式便于后续扩展(如添加 MLM、VQA 等任务)。


3. 数据准备

我选择 Flickr8k 数据集,它小而经典,适合显存有限的实验。

数据处理主要包括两步:

  1. captions.txt 转换为 JSON 格式,每张图像对应 5 条描述;
  2. 遍历所有 caption,统计词频并构建 vocab.json

示例代码(build_vocab.py):

1
2
3
4
5
6
7
8
9
10
counter = Counter()
for sample in data:
for caption in sample["captions"]:
tokens = caption.lower().split()
counter.update(tokens)

vocab = {"[PAD]": 0, "[UNK]": 1}
for word, freq in counter.items():
if freq >= 2:
vocab[word] = len(vocab)

4. 模型实现

4.1 Patch Embedding

将图像切成 patch 并投影到向量空间:

1
2
3
self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
x = self.proj(img) # [B, D, H', W']
x = x.flatten(2).transpose(1,2) # [B, N, D]

4.2 Text Embedding

文本部分参考 BERT embedding:

1
2
3
token_embed = nn.Embedding(vocab_size, embed_dim)
pos_embed = nn.Parameter(torch.randn(max_len, embed_dim))
type_embed = nn.Parameter(torch.randn(1, embed_dim))

4.3 多模态融合

拼接 [CLS] + image_patches + text_tokens,然后送入 Transformer:

1
2
3
4
fused = torch.cat([cls_token, image_embed, text_embed], dim=1)
out = transformer(fused)
cls = out[:, 0]
logits = itm_head(cls) # [B, 2]

4.4 Transformer 编码器

Multi-Head Self-Attention + FFN + LayerNorm

1
2
3
scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
attn = scores.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, N, D)

5. 训练流程

任务:Image-Text Matching (ITM)

  • 正样本:真实图文对
  • 负样本:随机打乱 caption
1
2
loss = nn.CrossEntropyLoss()(logits, labels)
acc = (logits.argmax(dim=1) == labels).float().mean()

保存训练日志:

1
2
3
with open("logs/training_log.csv", "a") as f:
writer = csv.writer(f)
writer.writerow([epoch, avg_loss, acc])

6. 训练效果

在 8GB 显存 + Flickr8k 上训练 3 epoch:

1
2
3
4
epoch,loss,acc
1,0.7086,0.5008
2,0.6994,0.5105
3,0.6975,0.5046

初期精度接近随机水平,但模型结构正确运行,后续可通过增大模型容量、调整学习率、增强负样本策略来提升性能。


7. 项目地址

👉 GitHub: MiniViLT