Vision-Language Pretraining

Vision-Language Pretraining (VLP) aims to learn multimodal representations from large-scale image-text data, providing strong foundation models for downstream vision-language tasks.

Task Taxonomy

Three Learning Paradigms

┌────────────────────────────────────────────────────────────────┐
│                     VLP Learning Paradigms                     │
├──────────────────┬──────────────────┬──────────────────────────┤
│   Contrastive    │    Generative    │          Hybrid          │
├──────────────────┼──────────────────┼──────────────────────────┤
│ CLIP, ALIGN      │ VL-T5, SimVLM    │ Flamingo, BLIP-2         │
│ Learns aligned   │ Learns to        │ Combines multiple        │
│ representations  │ generate         │ objectives               │
└──────────────────┴──────────────────┴──────────────────────────┘

Task Details

Contrastive Learning Tasks

| Task | Description            | Loss                 |
|------|------------------------|----------------------|
| ITC  | Image-text contrastive | InfoNCE              |
| ITM  | Image-text matching    | Binary cross-entropy |

Generative Learning Tasks

| Task | Description                         | Loss     |
|------|-------------------------------------|----------|
| ITG  | Image caption generation            | LM loss  |
| MLM  | Multimodal masked language modeling | MLM loss |
| MRM  | Masked region modeling              | MSE / CE |

Hybrid Learning

| Method | Objective combination |
|--------|-----------------------|
| UNITER | MLM + MRM + ITM + WRA |
| ViLT   | ITM + MLM             |
| BLIP   | ITC + ITM + LM        |
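
Hybrid methods typically train on a weighted sum of the individual objectives. A minimal sketch of the idea (the `hybrid_loss` helper and the weights are illustrative, not taken from any specific paper):

import torch

def hybrid_loss(losses, weights=None):
    """Weighted sum of per-objective losses (equal weights by default)."""
    weights = weights or {name: 1.0 for name in losses}
    return sum(weights[name] * loss for name, loss in losses.items())

# Toy usage with scalar stand-ins for the real objective losses
total = hybrid_loss({
    "itc": torch.tensor(1.2),
    "itm": torch.tensor(0.7),
    "lm":  torch.tensor(2.3),
})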

Pretraining Objectives in Detail

Image-Text Contrastive (ITC)

Learns to map matched images and texts to nearby points in a shared representation space:

import torch
import torch.nn.functional as F

def itc_loss(image_features, text_features, temperature=0.07):
    """
    ITC loss: maximize the similarity of matched pairs and
    minimize the similarity of mismatched pairs.
    """
    # L2-normalize both modalities
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    
    # Similarity matrix, scaled by temperature
    logits = image_features @ text_features.t() / temperature
    
    # Labels: the diagonal entries are the positive pairs
    batch_size = image_features.shape[0]
    labels = torch.arange(batch_size, device=image_features.device)
    
    # Symmetric loss over image-to-text and text-to-image directions
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.t(), labels)
    
    return (loss_i2t + loss_t2i) / 2
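
In equation form, the code above computes the symmetric InfoNCE objective, where $s_{ij}$ is the cosine similarity between image $i$ and text $j$, $\tau$ the temperature, and $N$ the batch size:

$$
\mathcal{L}_{\mathrm{ITC}} = -\frac{1}{2N}\sum_{i=1}^{N}\left(\log\frac{e^{s_{ii}/\tau}}{\sum_{j=1}^{N} e^{s_{ij}/\tau}} + \log\frac{e^{s_{ii}/\tau}}{\sum_{j=1}^{N} e^{s_{ji}/\tau}}\right)
$$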

See Contrastive Learning for more details.

Image-Text Matching (ITM)

Predicts whether a given image and text match; a fine-grained binary classification task:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ITMHead(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.fc = nn.Linear(hidden_dim * 2, 2)  # binary classification
    
    def forward(self, image_feat, text_feat, is_matched=None):
        # Fuse the two modalities by concatenation
        combined = torch.cat([image_feat, text_feat], dim=-1)
        logits = self.fc(combined)
        
        if is_matched is not None:
            loss = F.cross_entropy(logits, is_matched.long())
            return loss, logits
        return logits
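
Training ITM requires both matched and mismatched pairs in each batch. A minimal sketch of pair construction (the `build_itm_batch` helper is illustrative; methods such as ALBEF and BLIP additionally mine hard negatives using ITC similarities, which is omitted here):

import torch

def build_itm_batch(image_feat, text_feat):
    """Positives: aligned pairs. Negatives: each image paired with another example's text."""
    B = image_feat.size(0)
    neg_text = text_feat.roll(shifts=1, dims=0)  # shift texts by one within the batch
    images = torch.cat([image_feat, image_feat], dim=0)
    texts = torch.cat([text_feat, neg_text], dim=0)
    is_matched = torch.cat([torch.ones(B), torch.zeros(B)]).long()
    return images, texts, is_matched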

Masked Language Modeling (MLM)

Masks tokens in the text and predicts them from the visual context:

def mlm_loss(text_decoder, masked_ids, masked_labels, vision_features, vocab_size):
    """
    MLM: predict masked text tokens conditioned on visual features.
    `masked_labels` holds the original token ids at masked positions
    and -100 everywhere else.
    """
    # The text decoder predicts a vocabulary distribution at each position
    predictions = text_decoder(masked_ids, vision_features)  # (B, L, vocab_size)
    
    # Compute the loss only at masked positions (-100 is ignored)
    loss = F.cross_entropy(
        predictions.view(-1, vocab_size),
        masked_labels.view(-1),
        ignore_index=-100
    )
    return loss
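
A minimal sketch of how `masked_ids` and `masked_labels` can be built, assuming BERT-style inputs (the 15% mask rate and `mask_token_id=103` follow BERT conventions; real implementations also mix in random and kept tokens):

import torch

def mask_tokens(input_ids, mask_token_id=103, mask_prob=0.15):
    labels = input_ids.clone()
    mask = torch.rand(input_ids.shape) < mask_prob
    labels[~mask] = -100                  # loss is only computed where masked
    masked_ids = input_ids.clone()
    masked_ids[mask] = mask_token_id      # replace masked positions with [MASK]
    return masked_ids, labels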

Masked Region Modeling (MRM)

Masks regions of the image and predicts them from the textual context:

def mrm_loss(mr_decoder, mr_head, roi_head, image_features, masked_regions, text_features):
    """
    MRM: predict masked image regions conditioned on text features.
    """
    # Variant 1: masked region feature regression (MRFR)
    masked_feat_pred = mr_decoder(image_features, text_features)
    loss_fr = F.mse_loss(masked_feat_pred, masked_regions)
    
    # Variant 2: masked region classification (MRC) against detector labels
    region_labels = roi_head(masked_regions).argmax(dim=-1)  # pseudo-labels from the object detector
    region_cls_pred = mr_head(image_features)
    loss_cls = F.cross_entropy(region_cls_pred, region_labels)
    
    return loss_fr + loss_cls

Classic Architectures

UNITER

UNITER (UNiversal Image-TExt Representation) is a representative early VLP model:

class UNITER(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.image_encoder = ResNet50()  # or ViT; the original paper uses Faster R-CNN region features
        self.text_encoder = BERT()
        self.fusion = CrossEncoder(hidden_dim=768, layers=6)
        
        # Task heads
        self.itm_head = ITMHead(768)
        self.mlm_head = MLMHead(768, vocab_size=30522)
        self.mrm_head = MRMHead(768, num_classes=1601)
    
    def forward(self, images, input_ids, attention_mask,
                is_matched=None, masked_ids=None, region_labels=None):
        # Encode each modality independently
        img_feat = self.image_encoder(images)
        txt_feat = self.text_encoder(input_ids, attention_mask)
        
        # Cross-modal fusion
        fused = self.fusion(img_feat, txt_feat)
        
        # Multi-task losses
        losses = {}
        losses['itm'] = self.itm_head(fused, is_matched)
        losses['mlm'] = self.mlm_head(fused, masked_ids)
        losses['mrm'] = self.mrm_head(fused, region_labels)
        
        return sum(losses.values())

ViLT

ViLT (Vision-and-Language Transformer) was among the first models to process both modalities with a single unified Transformer, replacing deep modality-specific encoders with lightweight linear projections:

class ViLT(nn.Module):
    def __init__(self, config):
        super().__init__()
        # A single Transformer shared by both modalities
        self.transformer = Transformer(
            hidden_dim=256,
            num_layers=12,
            num_heads=8
        )
        
        # Lightweight linear projections (no deep visual backbone)
        self.patch_proj = nn.Linear(768, 256)  # image patch embeddings
        self.word_proj = nn.Linear(768, 256)   # BERT word embeddings
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 256))
        
        # Task heads
        self.itm_head = nn.Linear(256, 2)
        self.mlm_head = nn.Linear(256, 30522)
    
    def forward(self, image_features, text_features):
        # image_features: (B, num_patches, 768); text_features: (B, seq_len, 768)
        B, seq_len = text_features.shape[:2]
        
        # Project both modalities into the shared space
        img_emb = self.patch_proj(image_features)  # (B, num_patches, 256)
        txt_emb = self.word_proj(text_features)    # (B, seq_len, 256)
        
        # Concatenate [CLS] + text + image into a single sequence
        cls_token = self.cls_token.expand(B, -1, -1)
        embeds = torch.cat([cls_token, txt_emb, img_emb], dim=1)
        
        # Early fusion inside the shared Transformer
        output = self.transformer(embeds)
        
        # Task predictions: ITM from [CLS], MLM from the text positions
        itm_logits = self.itm_head(output[:, 0])
        mlm_logits = self.mlm_head(output[:, 1:seq_len + 1])
        
        return itm_logits, mlm_logits

ViLT vs. Earlier Models

| Model   | Visual encoder             | Fusion       |
|---------|----------------------------|--------------|
| ViLBERT | ResNet, two-stream         | Late fusion  |
| UNITER  | ResNet/ViT, two-stream     | Late fusion  |
| ViLT    | ViT patches, single-stream | Early fusion |

BLIP

BLIP (Bootstrapping Language-Image Pre-training) unifies understanding and generation tasks within a single pretraining framework:

class BLIP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.visual_encoder = ViT(image_size=224)
        self.text_encoder = BERT()  # image-grounded via cross-attention for ITM
        self.text_decoder = GPT2()  # causal decoder for captioning
        
        # Projections into the shared contrastive space
        self.vision_proj = nn.Linear(768, 256)
        self.text_proj = nn.Linear(768, 256)
        self.itm_head = nn.Linear(768, 2)
    
    def forward_text(self, image_embeds, input_ids, attention_mask, itm_labels):
        """Understanding objectives: ITC + ITM"""
        # ITC: contrast unimodal [CLS] features
        image_feat = self.vision_proj(image_embeds[:, 0])
        text_output = self.text_encoder(input_ids, attention_mask)
        text_feat = self.text_proj(text_output[:, 0])
        loss_itc = itc_loss(image_feat, text_feat)
        
        # ITM: fuse text with image via cross-attention, then classify the match
        fused = self.text_encoder(
            input_ids, attention_mask,
            encoder_hidden_states=image_embeds
        )
        loss_itm = F.cross_entropy(self.itm_head(fused[:, 0]), itm_labels)
        
        return loss_itc + loss_itm
    
    def forward_image_caption(self, image_embeds, input_ids, labels):
        """Generation objective: image captioning (LM loss)"""
        # Causal decoder attends to the image and generates the caption
        logits = self.text_decoder(
            input_ids=input_ids,
            encoder_hidden_states=image_embeds
        )
        
        loss_lm = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1)
        )
        return loss_lm

BLIP-2: Lightweight Multimodal Pretraining

The core innovation of BLIP-2 is the Q-Former, a lightweight trainable module that bridges a frozen image encoder and a frozen LLM:

┌──────────┐      ┌──────────┐      ┌──────────┐
│   CLIP   │ ←→   │ Q-Former │ ←→   │   LLM    │
│ (frozen) │      │(trained) │      │ (frozen) │
└──────────┘      └──────────┘      └──────────┘
  visual          cross-modal        language
  features        learning           generation

class BLIP2(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Frozen visual encoder
        self.visual_encoder = load_pretrained_clip()
        for param in self.visual_encoder.parameters():
            param.requires_grad = False
        
        # Q-Former: the only trainable component
        self.Qformer = QFormer(
            hidden_dim=768,
            num_query_tokens=32,  # size of the query bottleneck
            num_heads=12,
            num_layers=6
        )
        
        # Frozen LLM
        self.llm = load_pretrained_llm()
        for param in self.llm.parameters():
            param.requires_grad = False
        
        # Projection from Q-Former space into the LLM embedding space
        self.llm_proj = nn.Linear(768, self.llm.hidden_size)
    
    def forward(self, images, input_ids, training_mode="pretrain"):
        # Visual encoding (frozen)
        image_embeds = self.visual_encoder(images)
        
        # Q-Former: learned query tokens cross-attend to the visual features
        query_output = self.Qformer(encoder_hidden_states=image_embeds)
        
        # Project the 32 query vectors into the LLM space (a soft visual prompt)
        query_embeds = self.llm_proj(query_output)
        
        if training_mode == "pretrain":
            # Pretraining: language-modeling loss on the paired text
            outputs = self.llm(
                inputs_embeds=query_embeds,
                labels=input_ids
            )
            return outputs.loss
        else:
            # Inference: generate directly from the visual prompt
            return self.llm.generate(inputs_embeds=query_embeds)
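
The essential Q-Former mechanism can be sketched in a few lines: a fixed set of learned query tokens cross-attends to the frozen visual features, compressing any number of patches into 32 vectors. This is a simplified illustration with hypothetical module names; the real Q-Former interleaves self-attention, cross-attention, and feed-forward blocks initialized from BERT:

import torch
import torch.nn as nn

class TinyQFormer(nn.Module):
    def __init__(self, dim=768, num_queries=32, num_heads=12):
        super().__init__()
        # Learned query tokens, shared across all images
        self.queries = nn.Parameter(torch.randn(1, num_queries, dim) * 0.02)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
    
    def forward(self, image_embeds):  # (B, num_patches, dim)
        q = self.queries.expand(image_embeds.size(0), -1, -1)
        out, _ = self.cross_attn(q, image_embeds, image_embeds)
        return out  # (B, num_queries, dim): a fixed-size visual summary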

Pretraining Datasets

Common Datasets

| Dataset  | Size | Type                                |
|----------|------|-------------------------------------|
| CC3M     | 3M   | image-text pairs                    |
| CC12M    | 12M  | image-text pairs                    |
| SBU      | 1M   | image-text pairs                    |
| COCO     | 113K | images + captions + detection      |
| VG       | 100K | images + captions + relations      |
| LAION-2B | 2B+  | large-scale image-text pairs        |
| LAION-5B | 5B+  | extra-large-scale image-text pairs  |

Data Filtering

def filter_image_text_pairs(images, texts, min_length=3, max_length=77):
    """Filter out low-quality image-text pairs."""
    filtered_pairs = []
    
    for img, txt in zip(images, texts):
        # Text length filter (word count)
        if not (min_length <= len(txt.split()) <= max_length):
            continue
        
        # Image resolution check (optional)
        if img.width < 224 or img.height < 224:
            continue
        
        filtered_pairs.append((img, txt))
    
    return filtered_pairs
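
Beyond such heuristics, large corpora like LAION are filtered by CLIP score: a pair is kept only if the CLIP image-text cosine similarity clears a threshold (around 0.28 for LAION-5B's English subset). A sketch, assuming `model` exposes `encode_image`/`encode_text` returning feature vectors:

import torch
import torch.nn.functional as F

@torch.no_grad()
def clip_score_filter(model, images, texts, threshold=0.28):
    kept = []
    for img, txt in zip(images, texts):
        v = F.normalize(model.encode_image(img), dim=-1)
        t = F.normalize(model.encode_text(txt), dim=-1)
        if (v * t).sum(-1).item() >= threshold:  # cosine similarity
            kept.append((img, txt))
    return kept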

Downstream Task Transfer

Zero-Shot Transfer

Pretrained models can be applied directly to zero-shot tasks:

# CLIP-style zero-shot classification
def zero_shot_classification(model, image, class_names):
    # Encode the image and one prompt per class
    image_feat = F.normalize(model.encode_image(image), dim=-1)
    text_feats = F.normalize(torch.stack(
        [model.encode_text(f"a photo of a {c}") for c in class_names]
    ), dim=-1)
    
    # Cosine similarity against every class prompt
    similarities = text_feats @ image_feat
    return class_names[similarities.argmax().item()]

Fine-Tuning Transfer

def finetune_vlp(model, train_data, num_epochs=10):
    """Fine-tune on a downstream task."""
    # Unfreeze only the top layer of each encoder
    for name, param in model.named_parameters():
        if "text_encoder.layer.11" in name or "vision_encoder.layer.11" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False
    
    # Optimize only the trainable parameters
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=1e-5
    )
    
    for epoch in range(num_epochs):
        for batch in train_data:
            optimizer.zero_grad()
            loss = model(**batch)
            loss.backward()
            optimizer.step()

Connections to Existing Content

Related content:

- CLIP: foundations of contrastive learning
- Transformer: Vision Transformer architecture
- Expressive Power: theory of multimodal representation learning
- MoE: sparse experts in multimodal models
