# CLIP: Contrastive Language-Image Pre-training

CLIP (Contrastive Language-Image Pre-training), introduced by OpenAI in 2021, is a milestone in multimodal learning.[^1] It learns visual concepts from natural-language supervision, enabling flexible zero-shot classification.
## Core Idea

### Traditional Vision Models vs. CLIP
| Traditional approach | CLIP |
|---|---|
| Fixed class labels | Free-form text descriptions |
| Requires large amounts of labeled data | Leverages web-scale image-text data |
| Separate training per task | Train once, generalize across tasks |
| Cannot handle open vocabularies | Natively supports open vocabularies |
### Learning Paradigm

CLIP's core task is image-text matching: deciding whether a given image and a text description correspond.

Image → Vision Encoder → $I$

Text → Text Encoder → $T$

Predict the match: $\text{sim}(I, T) \approx 1$ (matched) or $0$ (unmatched)
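To make the matching setup concrete, here is a minimal sketch (with made-up dimensions and random features, not part of the original text): a batch of $N$ image-text pairs yields an $N \times N$ similarity matrix whose diagonal entries correspond to the matched pairs.

```python
import torch
import torch.nn.functional as F

# Toy illustration: pretend these are encoder outputs for a batch of 4 pairs.
N, d = 4, 512
image_features = F.normalize(torch.randn(N, d), dim=-1)  # I, unit-normalized
text_features = F.normalize(torch.randn(N, d), dim=-1)   # T, unit-normalized

# N x N cosine-similarity matrix; entry (i, j) compares image i with text j.
sim = image_features @ text_features.t()

# Training pushes the diagonal (matched pairs) up and the off-diagonal down.
print(sim.shape)       # torch.Size([4, 4])
print(sim.diagonal())  # similarities of the true (image, text) pairs
```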
## Model Architecture

### Dual-Encoder Design
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class CLIPModel(nn.Module):
    def __init__(self, embed_dim=512, vision_dim=768, text_dim=512,
                 num_heads=8, vision_width=768, text_width=512):
        super().__init__()
        self.image_encoder = VisionTransformer(embed_dim)
        self.text_encoder = TextTransformer(embed_dim)
        # Learnable temperature, stored as a log scale and initialized to 1/0.07 as in CLIP
        self.temperature = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))

    def encode_image(self, images):
        return self.image_encoder(images)

    def encode_text(self, texts):
        return self.text_encoder(texts)

    def forward(self, images, texts):
        # Dual-encoder forward pass
        image_features = self.encode_image(images)
        text_features = self.encode_text(texts)
        # L2 normalization
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        # Contrastive logits
        logit_scale = self.temperature.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text, image_features, text_features
```
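The learnable `temperature` is kept as a log scale; at CLIP's initialization of 0.07 the cosine similarities are multiplied by roughly 14.3, which sharpens the softmax over the batch (the paper additionally clips the scale so it never exceeds 100). A small standalone illustration with toy numbers, not part of the original text:

```python
import math
import torch

log_scale = torch.ones([]) * math.log(1 / 0.07)  # same initialization as above
print(log_scale.exp())                           # ~14.29: multiplier applied to cosine similarities

# Effect on a toy row of cosine similarities: scaling sharpens the softmax.
sims = torch.tensor([0.30, 0.25, 0.10])
print(torch.softmax(sims, dim=-1))                    # fairly flat distribution
print(torch.softmax(sims * log_scale.exp(), dim=-1))  # much more peaked on the best match
```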
### Vision Encoder: ViT

```python
class VisionTransformer(nn.Module):
    def __init__(self, embed_dim, image_size=224, patch_size=32,
                 width=768, layers=12, heads=12):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        # Patch embedding
        self.patch_embed = nn.Conv2d(3, width, kernel_size=patch_size,
                                     stride=patch_size)
        # Class token & position embedding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, width))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, width))
        # Transformer blocks
        self.transformer = nn.ModuleList([
            TransformerBlock(width, heads)
            for _ in range(layers)
        ])
        self.ln_final = nn.LayerNorm(width)
        self.proj = nn.Linear(width, embed_dim)

    def forward(self, x):
        B = x.shape[0]
        # Patch embedding: (B, 3, 224, 224) -> (B, 768, 7, 7) -> (B, 49, 768)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        # Prepend cls token and add position embedding
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        # Transformer blocks
        for block in self.transformer:
            x = block(x)
        x = self.ln_final(x)
        # Project the CLS token to the shared embedding space
        return self.proj(x[:, 0])
```
### Text Encoder: Transformer

```python
class TextTransformer(nn.Module):
    def __init__(self, embed_dim, vocab_size=49408, width=512,
                 layers=12, heads=8, max_len=77):
        super().__init__()
        self.embed_dim = embed_dim
        # Token embedding & positional embedding
        self.token_embedding = nn.Embedding(vocab_size, width)
        self.pos_embed = nn.Parameter(torch.zeros(1, max_len, width))
        # Transformer blocks
        self.transformer = nn.ModuleList([
            TransformerBlock(width, heads)
            for _ in range(layers)
        ])
        self.ln_final = nn.LayerNorm(width)
        self.proj = nn.Linear(width, embed_dim)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, texts):
        # texts: (B, L) token IDs ending with an EOS token
        x = self.token_embedding(texts)
        x = x + self.pos_embed[:, :x.size(1)]
        for block in self.transformer:
            x = block(x)
        x = self.ln_final(x)
        # Use the representation at the EOS token position. In CLIP the EOS token has
        # the largest ID in the vocabulary, so argmax over token IDs locates it even
        # when sequences are padded.
        eos_pos = texts.argmax(dim=-1)
        return self.proj(x[torch.arange(x.shape[0]), eos_pos])
```
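Both encoders above reference a `TransformerBlock` that the snippets do not define. For completeness, here is a minimal pre-norm block (an assumption on my part; CLIP's blocks follow a similar pre-LayerNorm attention-plus-MLP design, but this is not its exact code), plus a quick shape check that wires the pieces together:

```python
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    """Minimal pre-norm Transformer block (illustrative, not CLIP's exact implementation)."""
    def __init__(self, width, heads):
        super().__init__()
        self.ln_1 = nn.LayerNorm(width)
        self.attn = nn.MultiheadAttention(width, heads, batch_first=True)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, width * 4),
            nn.GELU(),
            nn.Linear(width * 4, width),
        )

    def forward(self, x):
        h = self.ln_1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # self-attention + residual
        x = x + self.mlp(self.ln_2(x))                      # MLP + residual
        return x


# Smoke test: push a tiny batch through the full dual encoder.
model = CLIPModel(embed_dim=512)
images = torch.randn(2, 3, 224, 224)
texts = torch.randint(1, 49408, (2, 77))  # dummy token IDs
logits_per_image, logits_per_text, img_f, txt_f = model(images, texts)
print(logits_per_image.shape, img_f.shape)  # torch.Size([2, 2]) torch.Size([2, 512])
```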
## Training Objective

### Symmetric Contrastive Loss

CLIP uses a symmetric InfoNCE loss that optimizes the image→text and text→image directions at the same time:
```python
def clip_loss(logits_per_image, logits_per_text):
    """
    Symmetric cross-entropy loss.
    logits_per_image: (batch_size, batch_size) - image-to-text logits
    logits_per_text:  (batch_size, batch_size) - text-to-image logits
    """
    batch_size = logits_per_image.shape[0]
    # Labels: each sample's positive is its paired element on the diagonal
    labels = torch.arange(batch_size, device=logits_per_image.device)
    # Image-to-text loss
    loss_image = F.cross_entropy(logits_per_image, labels)
    # Text-to-image loss
    loss_text = F.cross_entropy(logits_per_text, labels)
    # Symmetric loss
    return (loss_image + loss_text) / 2
```
### InfoNCE Loss in Detail

For the image-to-text direction, the InfoNCE loss is:

$$
\mathcal{L}_{I \to T} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp\!\big(\text{sim}(I_i, T_i) / \tau\big)}{\sum_{j=1}^{N} \exp\!\big(\text{sim}(I_i, T_j) / \tau\big)}
$$

where $\tau$ is the temperature parameter ($\tau = 0.07$ in CLIP) and $\text{sim}(\cdot,\cdot)$ is cosine similarity.
See InfoNCE and Contrastive Learning for details.
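As a sanity check on the equivalence between this formula and the `F.cross_entropy` call in `clip_loss`, the following sketch (toy numbers, not from the original text) computes the image-to-text InfoNCE term both ways and confirms they match:

```python
import torch
import torch.nn.functional as F

# Toy similarity matrix for a batch of 3 pairs, already divided by the temperature.
logits_per_image = torch.tensor([[ 5.0,  1.0, -2.0],
                                 [ 0.5,  4.0,  0.0],
                                 [-1.0,  0.0,  6.0]])
labels = torch.arange(3)

# (1) Cross-entropy with diagonal labels, as in clip_loss
loss_ce = F.cross_entropy(logits_per_image, labels)

# (2) The explicit InfoNCE formula: mean of -log softmax on the diagonal
log_probs = logits_per_image.log_softmax(dim=-1)
loss_infonce = -log_probs.diagonal().mean()

print(torch.allclose(loss_ce, loss_infonce))  # True
```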
## Zero-Shot Classification

### Prompt Engineering

CLIP performs zero-shot classification through text prompts:
```python
def zero_shot_classify(image, class_names, model, processor):
    """
    class_names: list of class names, e.g. ["a cat", "a dog", "a car"]
    """
    # Build prompts from a template
    prompts = [f"a photo of a {name}" for name in class_names]
    # Encode
    inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_features = model.encode_image(inputs['pixel_values'])
        text_features = model.encode_text(inputs['input_ids'])
    # Compute similarities
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    logits = (image_features @ text_features.t()) * model.temperature.exp()
    probs = logits.softmax(dim=-1)
    return probs
```
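For reference, the same zero-shot flow with a pretrained checkpoint via the Hugging Face `transformers` library might look like the sketch below (the checkpoint name is a commonly published option and the image path is a placeholder; adapt as needed):

```python
import torch
from PIL import Image
from transformers import CLIPModel as HFCLIPModel, CLIPProcessor

model = HFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.open("example.jpg")  # placeholder image path
prompts = ["a photo of a cat", "a photo of a dog", "a photo of a car"]

inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

# logits_per_image: (num_images, num_prompts); softmax gives class probabilities
probs = outputs.logits_per_image.softmax(dim=-1)
print(probs)
```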
### Prompt Templates

```python
# Different tasks use different prompt templates
imagenet_template = [
    lambda c: f'a photo of a {c}.',
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a photo of many {c}.',
    # ...
]

# Ensemble multiple prompts
def ensemble_prompts(class_names, templates):
    prompts = []
    for name in class_names:
        for template in templates:
            prompts.append(template(name))
    return prompts
```
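`ensemble_prompts` only expands the prompt list; in CLIP's reported prompt ensembling, the normalized embeddings of all prompts for a class are averaged into a single classifier weight per class. A hedged sketch of that step, assuming a hypothetical `tokenize` helper that maps strings to token IDs for the `TextTransformer` above:

```python
import torch
import torch.nn.functional as F

def build_zero_shot_weights(class_names, templates, model, tokenize):
    """Average normalized prompt embeddings into one weight vector per class (sketch)."""
    weights = []
    with torch.no_grad():
        for name in class_names:
            prompts = [template(name) for template in templates]
            tokens = tokenize(prompts)  # (num_prompts, seq_len); assumed helper
            emb = F.normalize(model.encode_text(tokens), dim=-1)
            class_weight = F.normalize(emb.mean(dim=0), dim=-1)  # average, then re-normalize
            weights.append(class_weight)
    return torch.stack(weights)  # (num_classes, embed_dim)

# Usage idea: logits = image_features @ build_zero_shot_weights(...).t()
```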
## CLIP Variants

### Open-Source Reproductions

| Model | Highlights | Scale |
|---|---|---|
| OpenCLIP | Open-source reproduction supporting more architectures (usage sketch below) | Multiple scales |
| EVA-CLIP | EVA-pretrained vision encoder | Large scale |
| DFN-CLIP | Data Filtering Networks to curate training data | Data-centric |
| CLIPA | Efficient training via an inverse scaling law (shortened image/text token sequences) | Efficient |
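A usage sketch for OpenCLIP via the `open_clip` package (the model and pretrained tags below are commonly published options and may need adjusting):

```python
import torch
import open_clip

# Load a pretrained OpenCLIP model plus its image preprocessing and tokenizer
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')

texts = tokenizer(["a photo of a cat", "a photo of a dog"])
with torch.no_grad():
    text_features = model.encode_text(texts)
print(text_features.shape)  # torch.Size([2, 512]) for ViT-B-32
```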
### Architecture Variants

| Variant | Change |
|---|---|
| ViT-L | Larger vision Transformer |
| ViT-H | Even larger scale |
| ConvNeXt | ConvNeXt as the vision encoder |
## Extension: Region-Level CLIP

### Region-Text Contrast

Standard CLIP only aligns whole images with text; region-level CLIP models (RegionCLIP, CLIP-Region) support aligning image regions with text:
```python
class RegionCLIP:
    def __init__(self, clip_model):
        self.clip = clip_model
        self.detector = FasterRCNN()  # placeholder region detector (not defined here)

    def match_region_text(self, image, regions, texts):
        """
        Match image regions against text descriptions (schematic).
        """
        # Detect regions and extract their features
        boxes, features = self.detector(image, regions)
        # Normalize region features
        region_features = F.normalize(features, dim=-1)
        text_features = F.normalize(
            self.clip.encode_text(texts), dim=-1
        )
        # Similarity matrix: (num_regions, num_texts)
        similarity = region_features @ text_features.t()
        return similarity
```
## Connections to Existing Content

| Related topic | Connection |
|---|---|
| Transformer | Multi-head attention in the text encoder |
| Contrastive learning | The InfoNCE loss function |
| Expressive power | The dual-encoder representation space |
| Neural network layers | Linear projections in the ViT |
## Code: Complete Training Example
```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


def train_clip(model, train_loader, optimizer, epochs, device):
    model = model.to(device)
    for epoch in range(epochs):
        total_loss = 0
        for batch in train_loader:
            images, texts = batch
            images = images.to(device)
            texts = texts.to(device)
            # Forward pass
            logits_per_image, logits_per_text, _, _ = model(images, texts)
            # Contrastive loss
            loss = clip_loss(logits_per_image, logits_per_text)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch}: Loss = {total_loss / len(train_loader):.4f}")


# Simplified loss function
def clip_loss(logits_per_image, logits_per_text):
    batch_size = logits_per_image.shape[0]
    labels = torch.arange(batch_size, device=logits_per_image.device)
    loss_i = F.cross_entropy(logits_per_image, labels)
    loss_t = F.cross_entropy(logits_per_text, labels)
    return (loss_i + loss_t) / 2
```
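A minimal way to exercise `train_clip` end to end, with a throwaway random dataset and an AdamW optimizer (the dataset, batch size, and learning rate are illustrative assumptions; only the 0.2 weight decay follows the paper):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy data: 8 random "images" and matching random token-ID "captions"
images = torch.randn(8, 3, 224, 224)
texts = torch.randint(1, 49408, (8, 77))
loader = DataLoader(TensorDataset(images, texts), batch_size=4, shuffle=True)

model = CLIPModel(embed_dim=512)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.2)

train_clip(model, loader, optimizer, epochs=1, device="cpu")
```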
## References

[^1]: Radford et al., "Learning Transferable Visual Models From Natural Language Supervision", ICML 2021.