Vision-Language Pretraining
Vision-Language Pretraining (VLP) aims to learn multimodal representations from large-scale image-text data, providing strong foundation models for downstream vision-language tasks.
Task Taxonomy
Three Learning Paradigms
┌─────────────────────────────────────────────────────────────────┐
│                     VLP Learning Paradigms                      │
├──────────────────┬──────────────────┬──────────────────────────┤
│   Contrastive    │    Generative    │          Hybrid          │
├──────────────────┼──────────────────┼──────────────────────────┤
│   CLIP, ALIGN    │  VL-T5, SimVLM   │    Flamingo, BLIP-2      │
│ learns aligned   │ learns to        │ combines multiple        │
│ representations  │ generate         │ objectives               │
└──────────────────┴──────────────────┴──────────────────────────┘
Task Details
Contrastive Tasks
| Task | Description | Loss function |
|---|---|---|
| ITC | Image-text contrastive | InfoNCE |
| ITM | Image-text matching | Binary cross-entropy |
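For reference, the InfoNCE objective behind ITC, written for the image-to-text direction (the text-to-image term is symmetric):

$$
\mathcal{L}_{\mathrm{i2t}} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(\mathrm{sim}(v_i, t_i)/\tau)}{\sum_{j=1}^{N} \exp(\mathrm{sim}(v_i, t_j)/\tau)}
$$

where $v_i$ and $t_i$ are the normalized image and text embeddings of the $i$-th pair, $\mathrm{sim}$ is cosine similarity, and $\tau$ is the temperature.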
Generative Tasks
| Task | Description | Loss function |
|---|---|---|
| ITG | Image-grounded text generation (captioning) | LM loss |
| MLM | Multimodal masked language modeling | MLM loss |
| MRM | Masked region modeling | MSE/CE |
Hybrid Learning
| Method | Objective combination |
|---|---|
| UNITER | ITM + MLM + MRM (+ word-region alignment) |
| ViLT | ITM + MLM |
| BLIP | ITC + ITM + LM |
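In practice a hybrid method optimizes a weighted sum of its objectives on every batch. A minimal sketch of the combination (the uniform default weights are illustrative, not taken from any specific paper):

def hybrid_loss(losses, weights=None):
    """Combine pretraining objectives, e.g. losses = {'itc': ..., 'itm': ..., 'lm': ...}."""
    weights = weights or {name: 1.0 for name in losses}
    return sum(weights[name] * loss for name, loss in losses.items())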
Pretraining Objectives in Detail
Image-Text Contrastive (ITC)
Learns to map matched images and texts to nearby points in a shared representation space:
import torch
import torch.nn as nn
import torch.nn.functional as F

def itc_loss(image_features, text_features, temperature=0.07):
    """
    ITC loss: maximize the similarity of matched pairs and
    minimize the similarity of mismatched pairs.
    """
    # L2-normalize both modalities
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    # Similarity matrix between all images and texts in the batch
    logits = image_features @ text_features.t() / temperature
    # Labels: the diagonal entries are the positive pairs
    batch_size = image_features.shape[0]
    labels = torch.arange(batch_size, device=image_features.device)
    # Symmetric loss over both retrieval directions
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.t(), labels)
    return (loss_i2t + loss_t2i) / 2

See Contrastive Learning for details.
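A quick, illustrative sanity check of itc_loss on random features (the shapes are arbitrary):

# Illustrative usage: a batch of 8 pairs with 256-dim features
image_features = torch.randn(8, 256)
text_features = torch.randn(8, 256)
loss = itc_loss(image_features, text_features)
# For random, unaligned features the loss sits roughly near ln(8) ≈ 2.08
print(loss.item())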
Image-Text Matching (ITM)
Decides whether an image and a text match; a fine-grained binary classification task:
class ITMHead(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.fc = nn.Linear(hidden_dim * 2, 2)  # binary classification

    def forward(self, image_feat, text_feat, is_matched=None):
        # Fuse the multimodal features by concatenation
        combined = torch.cat([image_feat, text_feat], dim=-1)
        logits = self.fc(combined)
        if is_matched is not None:
            loss = F.cross_entropy(logits, is_matched.long())
            return loss, logits
        return logits
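ITM is most informative with hard negatives: ALBEF and BLIP pick, for each image, the highest-ITC-similarity non-matching text as its negative. A minimal sketch of that sampling, assuming `sim_matrix` is an ITC similarity matrix like the one computed above:

def sample_hard_negatives(sim_matrix):
    """For each image, pick the most similar non-matching text index."""
    sim = sim_matrix.clone()
    # Exclude the positives on the diagonal
    sim.fill_diagonal_(float('-inf'))
    return sim.argmax(dim=1)  # (B,) indices of hard-negative texts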
Masked Language Modeling (MLM)
Masks tokens in the text and predicts them from the visual context:
def mlm_loss(text_decoder, masked_ids, masked_labels, vision_features, vocab_size):
    """
    MLM: predict masked text tokens conditioned on visual features.
    masked_labels holds the original token id at each masked position
    and -100 everywhere else (ignored by cross_entropy).
    """
    # The text decoder predicts a token distribution at every position
    predictions = text_decoder(masked_ids, vision_features)  # (B, L, vocab_size)
    # Compute the loss only at masked positions
    loss = F.cross_entropy(
        predictions.view(-1, vocab_size),
        masked_labels.view(-1),
        ignore_index=-100,
    )
    return loss

Masked Region Modeling (MRM)
Masks regions of the image and predicts them from the textual context:
def mrm_loss(image_features, masked_regions, text_features,
             mr_decoder, mr_head, roi_head):
    """
    MRM: predict masked image regions from textual features.
    mr_decoder / mr_head / roi_head are the model's region heads.
    """
    # Variant 1: masked region feature regression (MRFR)
    masked_feat_pred = mr_decoder(image_features, text_features)
    loss_fr = F.mse_loss(masked_feat_pred, masked_regions)
    # Variant 2: masked region classification (MRC), supervised by
    # the object detector's predicted class for each masked region
    region_labels = roi_head(masked_regions).argmax(dim=-1)  # pseudo-labels
    region_cls_pred = mr_head(image_features)
    loss_cls = F.cross_entropy(region_cls_pred, region_labels)
    return loss_fr + loss_cls

Classic Architectures
UNITER
UNITER (UNiversal Image-TExt Representation) is a representative early VLP model:
class UNITER(nn.Module):
    def __init__(self, config):
        super().__init__()
        # The paper uses Faster R-CNN region features;
        # grid features from a ResNet or ViT are a common substitute
        self.image_encoder = ResNet50()
        self.text_encoder = BERT()
        self.fusion = CrossEncoder(hidden_dim=768, layers=6)
        # Task heads (schematic: each returns its loss given features and labels)
        self.itm_head = ITMHead(768)
        self.mlm_head = MLMHead(768, vocab_size=30522)
        self.mrm_head = MRMHead(768, num_classes=1601)

    def forward(self, images, input_ids, attention_mask,
                is_matched=None, masked_ids=None, region_labels=None):
        # Encode each modality independently
        img_feat = self.image_encoder(images)
        txt_feat = self.text_encoder(input_ids, attention_mask)
        # Joint fusion encoder
        fused = self.fusion(img_feat, txt_feat)
        # Multi-task losses
        losses = {}
        losses['itm'] = self.itm_head(fused, is_matched)
        losses['mlm'] = self.mlm_head(fused, masked_ids)
        losses['mrm'] = self.mrm_head(fused, region_labels)
        return sum(losses.values())

ViLT
ViLT (Vision-and-Language Transformer) was one of the first models to use a single unified Transformer for VLP:
class ViLT(nn.Module):
    def __init__(self, config):
        super().__init__()
        # A single Transformer shared by both modalities
        self.transformer = Transformer(
            hidden_dim=256,
            num_layers=12,
            num_heads=8
        )
        # Lightweight linear projections replace a heavy visual backbone
        self.patch_proj = nn.Linear(768, 256)  # flattened ViT-B/32 patches
        self.word_proj = nn.Linear(768, 256)   # BERT word embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 256))
        # Task heads
        self.itm_head = nn.Linear(256, 2)
        self.mlm_head = nn.Linear(256, 30522)

    def forward(self, patch_features, word_embeddings, attention_mask):
        B, num_patches, _ = patch_features.shape
        # Project both modalities to the shared width
        img_emb = self.patch_proj(patch_features)   # (B, num_patches, 256)
        txt_emb = self.word_proj(word_embeddings)   # (B, seq_len, 256)
        # Concatenate [CLS], image patches, and text tokens
        cls_token = self.cls_token.expand(B, -1, -1)
        embeds = torch.cat([cls_token, img_emb, txt_emb], dim=1)
        # Single-stream Transformer: early fusion
        output = self.transformer(embeds)
        # Task predictions
        itm_logits = self.itm_head(output[:, 0])                 # [CLS] position
        mlm_logits = self.mlm_head(output[:, 1 + num_patches:])  # text positions
        return itm_logits, mlm_logits

ViLT vs. earlier models:
| Model | Visual encoding | Fusion | Params |
|---|---|---|---|
| ViLBERT | Faster R-CNN regions | Two-stream, late fusion | Large |
| UNITER | Faster R-CNN regions | Single-stream Transformer | Large |
| ViLT | Linear patch projection | Single-stream, early fusion | Small |
BLIP
BLIP (Bootstrapping Language-Image Pre-training) innovatively unifies understanding and generation tasks:
class BLIP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.visual_encoder = ViT(image_size=224)
        # Multimodal mixture of encoder-decoder (MED): the text encoder
        # doubles as an image-grounded encoder via cross-attention, and
        # a causally masked decoder handles generation
        self.text_encoder = BERT()
        self.text_decoder = TextDecoder()  # BERT-style decoder with causal masking
        # Projections into the shared contrastive space
        self.vision_proj = nn.Linear(768, 256)
        self.text_proj = nn.Linear(768, 256)
        self.itm_head = nn.Linear(768, 2)

    def forward_understanding(self, images, input_ids, attention_mask, is_matched):
        """Understanding objectives: ITC + ITM"""
        image_embeds = self.visual_encoder(images)
        text_embeds = self.text_encoder(input_ids, attention_mask)
        # ITC on unimodal [CLS] features
        image_feat = self.vision_proj(image_embeds[:, 0])
        text_feat = self.text_proj(text_embeds[:, 0])
        loss_itc = itc_loss(image_feat, text_feat)
        # ITM on the image-grounded text encoder (cross-attends to the image)
        fused = self.text_encoder(input_ids, attention_mask,
                                  encoder_hidden_states=image_embeds)
        itm_logits = self.itm_head(fused[:, 0])
        loss_itm = F.cross_entropy(itm_logits, is_matched.long())
        return loss_itc + loss_itm

    def forward_caption(self, images, input_ids, labels):
        """Generation objective: image captioning with an LM loss"""
        image_embeds = self.visual_encoder(images)
        # The causal decoder cross-attends to the image features
        logits = self.text_decoder(input_ids, encoder_hidden_states=image_embeds)
        loss_lm = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss_lm

BLIP-2: Lightweight Multimodal Pretraining
The core innovation of BLIP-2 is the Q-Former, which serves as a bridge between modalities:
┌──────────┐      ┌──────────┐      ┌──────────┐
│   CLIP   │ ←→   │ Q-Former │ ←→   │   LLM    │
│ (frozen) │      │(trained) │      │ (frozen) │
└──────────┘      └──────────┘      └──────────┘
  visual           cross-modal        language
  features         learning           generation
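The heart of the Q-Former is a small set of learnable query tokens that cross-attend to the frozen image features, compressing a variable-length patch sequence into a fixed 32-token bottleneck. A minimal sketch of that mechanism, using a single cross-attention layer (the real Q-Former interleaves self-attention, cross-attention, and feed-forward blocks initialized from BERT):

class QueryBottleneck(nn.Module):
    """Learnable queries cross-attending to frozen image features."""
    def __init__(self, hidden_dim=768, num_query_tokens=32, num_heads=12):
        super().__init__()
        self.query_tokens = nn.Parameter(torch.zeros(1, num_query_tokens, hidden_dim))
        self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

    def forward(self, image_embeds):  # (B, num_patches, D)
        queries = self.query_tokens.expand(image_embeds.size(0), -1, -1)
        out, _ = self.cross_attn(queries, image_embeds, image_embeds)
        return out  # (B, 32, D): fixed-size visual summary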
class BLIP2(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Frozen visual encoder
        self.visual_encoder = load_pretrained_clip()
        for param in self.visual_encoder.parameters():
            param.requires_grad = False
        # Q-Former: the only trainable module
        self.Qformer = QFormer(
            hidden_dim=768,
            num_query_tokens=32,  # size of the query bottleneck
            num_heads=12,
            num_layers=6
        )
        # Frozen LLM
        self.llm = load_pretrained_llm()
        for param in self.llm.parameters():
            param.requires_grad = False
        # Projection into the LLM embedding space
        self.llm_proj = nn.Linear(768, self.llm.hidden_size)

    def forward(self, images, input_ids, training_mode="pretrain"):
        # Visual encoding (no gradients flow into the frozen encoder)
        image_embeds = self.visual_encoder(images)
        # Q-Former: learnable queries cross-attend to the image features
        query_output = self.Qformer(encoder_hidden_states=image_embeds)
        # Project the queries into LLM space; they act as soft visual prompts
        query_embeds = self.llm_proj(query_output)
        if training_mode == "pretrain":
            # Generative pretraining: LM loss on the caption conditioned on
            # the queries (the caption embeddings are appended in practice)
            outputs = self.llm(
                inputs_embeds=query_embeds,
                labels=input_ids
            )
            return outputs.loss
        else:
            # Inference: generate directly from the visual prompt
            return self.llm.generate(inputs_embeds=query_embeds)

Pretraining Datasets
Common Datasets
| Dataset | Size | Type |
|---|---|---|
| CC3M | 3M | image-text pairs |
| CC12M | 12M | image-text pairs |
| SBU | 1M | image-text pairs |
| COCO | 113K | images + captions + detection |
| VG | 100K | images + captions + relations |
| LAION-2B | 2B+ | large-scale image-text pairs |
| LAION-5B | 5B+ | web-scale image-text pairs |
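These corpora are typically consumed as streams of (image, caption) pairs. A minimal PyTorch Dataset sketch for a local folder of pairs — the file layout and the `captions.tsv` name are assumptions for illustration:

from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset

class ImageTextDataset(Dataset):
    """Pairs each image file with its caption from a TSV of <filename>\t<caption> lines."""
    def __init__(self, root, transform=None):
        self.root = Path(root)
        self.transform = transform
        lines = (self.root / "captions.tsv").read_text().splitlines()
        self.samples = [line.split("\t", 1) for line in lines]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        filename, caption = self.samples[idx]
        image = Image.open(self.root / filename).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, caption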
Data Filtering
def filter_image_text_pairs(images, texts, min_length=3, max_length=77):
    """Filter out low-quality image-text pairs"""
    filtered_pairs = []
    for img, txt in zip(images, texts):
        # Filter by caption length (in words)
        if not (min_length <= len(txt.split()) <= max_length):
            continue
        # Optional image-quality check: drop very small images
        if img.width < 224 or img.height < 224:
            continue
        filtered_pairs.append((img, txt))
    return filtered_pairs

Downstream Task Transfer
Zero-Shot Transfer
Pretrained models can be used directly on zero-shot tasks:
# CLIP-style zero-shot classification
def zero_shot_classification(model, image, class_names):
    # Encode the image and a prompted version of each class name
    image_feat = F.normalize(model.encode_image(image), dim=-1)
    text_feats = torch.stack([
        F.normalize(model.encode_text(f"a photo of a {c}"), dim=-1)
        for c in class_names
    ])
    # Cosine similarity reduces to a dot product on normalized features
    similarities = text_feats @ image_feat
    return class_names[similarities.argmax().item()]

Fine-Tuning Transfer
def finetune_vlp(model, train_data, num_epochs=10):
    """Fine-tune a pretrained VLP model on a downstream task"""
    # Unfreeze only the top layer of each encoder
    for name, param in model.named_parameters():
        if "text_encoder.layer.11" in name or "vision_encoder.layer.11" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False
    # Optimize only the trainable parameters
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=1e-5
    )
    for epoch in range(num_epochs):
        for batch in train_data:
            loss = model(**batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

Links to Related Content
| Related topic | Content |
|---|---|
| CLIP | Foundations of contrastive learning |
| Transformer | Vision Transformer architecture |
| Expressiveness | Theory of multimodal representation learning |
| MoE | Sparse experts in multimodal models |