CLIP: Contrastive Language-Image Pre-training

CLIP (Contrastive Language-Image Pre-training), introduced by OpenAI in 2021, is a milestone in multimodal learning.[1] It learns visual concepts from natural-language supervision, enabling flexible zero-shot classification.

Core Idea

Traditional Vision vs. CLIP

| Traditional approach | CLIP |
| --- | --- |
| Fixed class labels | Free-form text descriptions |
| Needs large labeled datasets | Leverages web-scale image-text pairs |
| Separate training per task | Train once, generalize across tasks |
| Cannot handle open vocabularies | Naturally supports open vocabularies |

Learning Paradigm

CLIP's core task is image-text matching: decide whether a given image and text description correspond.

Image → Vision Encoder → $I$
Text  → Text Encoder  → $T$

Predicted match: $\text{sim}(I, T)$ should be high for matching pairs and low for mismatched ones.
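
Concretely, with L2-normalized embeddings the similarity is just a dot product; a minimal sketch, with toy random vectors standing in for encoder outputs:

import torch
import torch.nn.functional as F

# Toy embeddings in place of real encoder outputs
I = F.normalize(torch.randn(1, 512), dim=-1)  # image embedding
T = F.normalize(torch.randn(1, 512), dim=-1)  # text embedding
sim = (I @ T.t()).item()  # cosine similarity, in [-1, 1]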

Model Architecture

Dual-Encoder Design

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class CLIPModel(nn.Module):
    def __init__(self, embed_dim=512, vision_width=768, text_width=512):
        super().__init__()
        self.image_encoder = VisionTransformer(embed_dim, width=vision_width)
        self.text_encoder = TextTransformer(embed_dim, width=text_width)
        # Learnable logit scale, initialized to log(1/0.07) as in the CLIP paper
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
    
    def encode_image(self, images):
        return self.image_encoder(images)
    
    def encode_text(self, texts):
        return self.text_encoder(texts)
    
    def forward(self, images, texts):
        # Encode both modalities with independent encoders
        image_features = self.encode_image(images)
        text_features = self.encode_text(texts)
        
        # L2-normalize so that dot products are cosine similarities
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        
        # Contrastive logits: scaled cosine-similarity matrix (B x B)
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        
        return logits_per_image, logits_per_text, image_features, text_features
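
Both encoders below instantiate a TransformerBlock that this section never defines. A minimal pre-norm sketch (an assumption, not CLIP's exact block; the real text encoder additionally applies a causal attention mask):

class TransformerBlock(nn.Module):
    """Minimal pre-norm Transformer block (sketch)."""
    def __init__(self, width, heads):
        super().__init__()
        self.ln_1 = nn.LayerNorm(width)
        self.attn = nn.MultiheadAttention(width, heads, batch_first=True)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, width * 4),
            nn.GELU(),
            nn.Linear(width * 4, width),
        )
    
    def forward(self, x):
        h = self.ln_1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # self-attention + residual
        x = x + self.mlp(self.ln_2(x))                     # MLP + residual
        return x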

Vision Encoder: ViT

class VisionTransformer(nn.Module):
    def __init__(self, embed_dim, image_size=224, patch_size=32, 
                 width=768, layers=12, heads=12):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        
        # Patch embedding
        self.patch_embed = nn.Conv2d(3, width, kernel_size=patch_size, 
                                     stride=patch_size)
        
        # Class token & position embedding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, width))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, width))
        
        # Transformer blocks
        self.transformer = nn.ModuleList([
            TransformerBlock(width, heads) 
            for _ in range(layers)
        ])
        
        self.ln_final = nn.LayerNorm(width)
        self.proj = nn.Linear(width, embed_dim)
    
    def forward(self, x):
        B = x.shape[0]
        
        # Patch embedding: (B, 3, 224, 224) -> (B, 768, 7, 7) -> flatten/transpose -> (B, 49, 768)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        
        # Add cls token and position embedding
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        
        # Transformer blocks
        for block in self.transformer:
            x = block(x)
        
        x = self.ln_final(x)
        # Project the CLS token into the shared embedding space
        return self.proj(x[:, 0])
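
A quick shape check of the patch arithmetic (with the default patch_size=32, 224/32 = 7, so 49 patches plus one CLS token), assuming the TransformerBlock sketch above:

vit = VisionTransformer(embed_dim=512)
x = torch.randn(2, 3, 224, 224)
print(vit(x).shape)  # torch.Size([2, 512])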

Text Encoder: Transformer

class TextTransformer(nn.Module):
    def __init__(self, embed_dim, vocab_size=49408, width=512, 
                 layers=12, heads=8, max_len=77):
        super().__init__()
        self.embed_dim = embed_dim
        
        # Token embedding & position encoding
        self.token_embedding = nn.Embedding(vocab_size, width)
        self.pos_embed = nn.Parameter(torch.zeros(1, max_len, width))
        
        # Transformer blocks
        self.transformer = nn.ModuleList([
            TransformerBlock(width, heads)
            for _ in range(layers)
        ])
        
        self.ln_final = nn.LayerNorm(width)
        self.proj = nn.Linear(width, embed_dim)
    
    @property
    def device(self):
        return next(self.parameters()).device
    
    def forward(self, texts):
        # texts: (B, L) token IDs, each sequence ending with an EOS token
        x = self.token_embedding(texts)
        x = x + self.pos_embed[:, :x.size(1)]
        
        for block in self.transformer:
            x = block(x)
        
        x = self.ln_final(x)
        # Pool at the EOS position. In CLIP's BPE vocabulary the EOS token has the
        # highest ID, so argmax over token IDs locates it even with padding.
        eos_pos = texts.argmax(dim=-1)
        return self.proj(x[torch.arange(x.size(0)), eos_pos])
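
With the sketched TransformerBlock in place, the full dual-encoder can be smoke-tested on random inputs (untrained weights, shapes only):

model = CLIPModel(embed_dim=512)
images = torch.randn(4, 3, 224, 224)
texts = torch.randint(1, 49408, (4, 77))  # random token IDs as stand-ins
logits_per_image, logits_per_text, _, _ = model(images, texts)
print(logits_per_image.shape)  # torch.Size([4, 4]): one logit per image-text pair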

Training Objective

Symmetric Contrastive Loss

CLIP uses a symmetric InfoNCE loss that optimizes both the image→text and text→image directions simultaneously:

def clip_loss(logits_per_image, logits_per_text):
    """
    Symmetric cross-entropy loss.
    logits_per_image: (batch_size, batch_size) - each image scored against all texts
    logits_per_text: (batch_size, batch_size) - each text scored against all images
    """
    batch_size = logits_per_image.shape[0]
    
    # Labels: the positive pair for sample i sits on the diagonal
    labels = torch.arange(batch_size, device=logits_per_image.device)
    
    # Image-to-text loss
    loss_image = F.cross_entropy(logits_per_image, labels)
    # Text-to-image loss
    loss_text = F.cross_entropy(logits_per_text, labels)
    
    # Symmetric average
    return (loss_image + loss_text) / 2

InfoNCE Loss in Detail

For the image→text direction, the InfoNCE loss is:

$$
\mathcal{L}_{I \to T} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(\text{sim}(I_i, T_i) / \tau)}{\sum_{j=1}^{N} \exp(\text{sim}(I_i, T_j) / \tau)}
$$

where $\tau$ is the temperature parameter (initialized to $0.07$ in CLIP) and $\text{sim}(\cdot, \cdot)$ is cosine similarity.
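
The cross-entropy formulation in clip_loss above is exactly this objective; a quick numerical check:

# Verify: cross-entropy over similarity logits equals the InfoNCE formula
I = F.normalize(torch.randn(4, 512), dim=-1)
T = F.normalize(torch.randn(4, 512), dim=-1)
logits = (I @ T.t()) / 0.07
labels = torch.arange(4)
ce = F.cross_entropy(logits, labels)
infonce = -(logits[labels, labels] - logits.logsumexp(dim=-1)).mean()
torch.testing.assert_close(ce, infonce)  # the two are identical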

For details, see InfoNCE and Contrastive Learning.

Zero-Shot Classification

Prompt Engineering

CLIP performs zero-shot classification through text prompts:

def zero_shot_classify(image, class_names, model, processor):
    """
    class_names: list of class names, e.g. ["a cat", "a dog", "a car"]
    """
    # Build prompts from a template
    prompts = [f"a photo of a {name}" for name in class_names]
    
    # Tokenize and preprocess (HuggingFace-style processor)
    inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_features = model.encode_image(inputs['pixel_values'])
        text_features = model.encode_text(inputs['input_ids'])
    
    # Cosine similarity between the image and every class prompt
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    
    logits = (image_features @ text_features.t()) * model.logit_scale.exp()
    probs = logits.softmax(dim=-1)
    
    return probs
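
A hypothetical call (pil_image and processor are assumed to come from, e.g., a PIL image and HuggingFace's CLIPProcessor):

probs = zero_shot_classify(pil_image, ["cat", "dog", "car"], model, processor)
pred = probs.argmax(dim=-1)  # index of the highest-probability class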

Prompt Templates

# Different tasks suit different prompt templates
imagenet_templates = [
    lambda c: f'a photo of a {c}.',
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a photo of many {c}.',
    # ...
]
 
# Build the full prompt set by combining every class with every template
def ensemble_prompts(class_names, templates):
    prompts = []
    for name in class_names:
        for template in templates:
            prompts.append(template(name))
    return prompts
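
Note that ensemble_prompts only concatenates strings; standard CLIP prompt ensembling instead averages the normalized text embeddings per class. A sketch, where tokenize is an assumed callable mapping a list of strings to (N, 77) token IDs:

def ensemble_class_embeddings(class_names, templates, model, tokenize):
    class_embeddings = []
    with torch.no_grad():
        for name in class_names:
            tokens = tokenize([t(name) for t in templates])
            feats = F.normalize(model.encode_text(tokens), dim=-1)
            # Average over templates, then re-normalize onto the unit sphere
            class_embeddings.append(F.normalize(feats.mean(dim=0), dim=-1))
    return torch.stack(class_embeddings)  # (num_classes, embed_dim)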

CLIP Variants

Open-Source Reproductions

| Model | Highlights | Scale |
| --- | --- | --- |
| OpenCLIP | Open-source reproduction supporting many architectures | Multiple sizes |
| EVA-CLIP | EVA-pretrained vision encoder | Large scale |
| DFN-CLIP | Data Filtering Networks to curate training data | Data-driven |
| CLIPA | Efficient CLIP training via reduced token lengths (inverse scaling law) | Efficient |

Architecture Variants

| Variant | Improvement |
| --- | --- |
| ViT-L | Larger vision Transformer |
| ViT-H | Even larger scale |
| ConvNeXt-ViT | ConvNeXt as the vision encoder |

Extension: Region-Level CLIP

Region-Text Contrast

Standard CLIP only aligns whole images with text; region-level variants (RegionCLIP, CLIP-Region) align individual image regions with text:

class RegionCLIP:
    def __init__(self, clip_model):
        self.clip = clip_model
        # Placeholder: a region-proposal detector (e.g. Faster R-CNN) that
        # returns boxes plus per-region features in the CLIP embedding space
        self.detector = FasterRCNN()
    
    def match_region_text(self, image, regions, texts):
        """
        Match image regions against text descriptions.
        """
        # Detect regions and extract their features
        boxes, features = self.detector(image, regions)
        
        # Normalize region features
        region_features = F.normalize(features, dim=-1)
        text_features = F.normalize(
            self.clip.encode_text(texts), dim=-1
        )
        
        # Region-text similarity matrix: (num_regions, num_texts)
        similarity = region_features @ text_features.t()
        return similarity

Connections to Existing Content

| Related content | Connection |
| --- | --- |
| Transformer | Multi-head attention in the text encoder |
| Contrastive learning | The InfoNCE loss function |
| Expressive power | The dual-encoder representation space |
| Neural network layers | Linear projections in the ViT |

Code: Full Training Example

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
 
def train_clip(model, train_loader, optimizer, epochs, device):
    model = model.to(device)
    
    for epoch in range(epochs):
        total_loss = 0
        for batch in train_loader:
            images, texts = batch
            images = images.to(device)
            texts = texts.to(device)
            
            # Forward pass
            logits_per_image, logits_per_text, _, _ = model(images, texts)
            
            # Contrastive loss
            loss = clip_loss(logits_per_image, logits_per_text)
            
            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch}: Loss = {total_loss / len(train_loader):.4f}")
 
# Symmetric contrastive loss (repeated from above so this example is self-contained)
def clip_loss(logits_per_image, logits_per_text):
    batch_size = logits_per_image.shape[0]
    labels = torch.arange(batch_size, device=logits_per_image.device)
    loss_i = F.cross_entropy(logits_per_image, labels)
    loss_t = F.cross_entropy(logits_per_text, labels)
    return (loss_i + loss_t) / 2
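
A hypothetical launch; the data loader is assumed to yield (images, token_ids) batches, and the hyperparameters are illustrative rather than the paper's:

model = CLIPModel(embed_dim=512)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.2)
# train_loader: assumed DataLoader yielding (images, token_ids) batches
train_clip(model, train_loader, optimizer, epochs=10, device="cuda")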

References

1. Radford et al., "Learning Transferable Visual Models From Natural Language Supervision," ICML 2021.