对比学习与InfoNCE

对比学习(Contrastive Learning)是自监督学习的核心技术之一,通过对比正样本对和负样本对来学习数据的有效表示。InfoNCE(Noise-Contrastive Estimation)是其中最重要的损失函数之一,源于互信息的下界估计。

核心思想

对比学习的目标

学习一个表示 ,使得:

  • 正样本对 的表示尽量相似
  • 负样本对 的表示尽量不同
正样本对: 同一图像的不同增强视图
         x ────────(+)───▶ z
         x_aug ──(+)───▶ z_aug
         
负样本对: 不同图像的表示
         x ────────(-)───▶ z
         x_neg ──(-)───▶ z_neg

对比损失的形式

常见的对比损失包括:

损失函数公式特点
Contrastive Loss二元形式
Triplet Loss三元组
NT-Xent无监督
InfoNCE基于互信息估计理论基础更强

InfoNCE 损失函数

从互信息到InfoNCE

InfoNCE 损失源自 Noise-Contrastive Estimation (NCE) 方法,核心思想是将密度估计问题转化为分类问题。

推导过程

目标:估计互信息

已知:互信息可以表示为

问题:分母 未知且难以计算。

解决方案:使用 NCE 将 替换为噪声分布

InfoNCE 损失

对于一个 batch 中的 个样本,使用 个负样本:

其中:

  • 相似度函数(通常用余弦相似度)
  • 温度参数,控制分布的 sharpness
  • 的正样本(正样本对)
  • 包括 的正样本和 个负样本

代码实现

import torch
import torch.nn.functional as F
import torch.nn as nn
 
def info_nce_loss(z_i, z_j, temperature=0.07):
    """
    InfoNCE 损失函数
    
    Args:
        z_i: 第一个视图的表示 (batch_size, d)
        z_j: 第二个视图的表示 (batch_size, d)
        temperature: 温度参数 τ
    
    Returns:
        loss: InfoNCE 损失
    """
    batch_size = z_i.shape[0]
    
    # L2 归一化
    z_i = F.normalize(z_i, dim=1)
    z_j = F.normalize(z_j, dim=1)
    
    # 计算相似度矩阵
    # [2N, 2N] 矩阵:两个视图拼接后计算
    z = torch.cat([z_i, z_j], dim=0)  # (2N, d)
    sim_matrix = torch.matmul(z, z.T) / temperature  # (2N, 2N)
    
    # 对角线是同一图像两个视图的相似度(正样本)
    # 我们需要区分哪些是正对,哪些是负对
    
    # 创建 mask:正样本对的位置
    # (i, i+N) 和 (i+N, i) 是正对
    N = batch_size
    labels = torch.arange(N, device=z_i.device)
    
    # 组合相似度: [z_i|z_j] 和 [z_j|z_i] 都要计算
    # 对于 z_i,正样本是 z_j(索引 N+i)
    # 对于 z_j,正样本是 z_i(索引 i)
    
    # 方式1:SimCLR 风格
    sim_ij = torch.sum(z_i * z_j, dim=1) / temperature  # (N,)
    sim_ji = sim_ij  # 对称
    
    # 构建全部 logits
    # row i 对应 z_i concatenation 后的第 i 行
    logits = torch.cat([
        torch.cat([z_i[:0], z_i], dim=0),   # z_i 与 z_i 的相似度(排除自身)
        torch.cat([z_j, z_j[:0]], dim=0)    # z_j 与 z_j 的相似度(排除自身)
    ], dim=1) / temperature
    
    # 简化实现
    def contrastive_loss(z_i, z_j, temperature):
        """
        简洁的 SimCLR 风格 InfoNCE 实现
        """
        batch_size = z_i.shape[0]
        
        # 归一化
        z_i = F.normalize(z_i, dim=1)
        z_j = F.normalize(z_j, dim=1)
        
        # 相似度矩阵
        # [z_i, z_j] 是 2N x d
        # [z_i, z_j]^T @ [z_i, z_j] 是 2N x 2N
        z = torch.cat([z_i, z_j], dim=0)
        sim = torch.mm(z, z.T) / temperature
        
        # 掩码:排除自身与跨视图
        # 正样本:(i, i+N) 和 (i+N, i)
        N = batch_size
        sim_i_pos = torch.diag(sim, N)      # z_i 与对应 z_j 的相似度
        sim_j_pos = torch.diag(sim, -N)     # z_j 与对应 z_i 的相似度
        
        # 拼接成正样本 logits
        pos = torch.cat([sim_i_pos, sim_j_pos], dim=0)
        
        # 所有 logits(包含自身)需要掩码
        mask = torch.eye(2 * N, device=sim.device, dtype=torch.bool)
        sim.masked_fill_(mask, -float('inf'))
        
        # 负样本 logits
        neg = sim.logsumexp(dim=1)
        
        # InfoNCE 损失
        loss = -torch.mean(pos - neg)
        
        return loss
    
    return contrastive_loss(z_i, z_j, temperature)

PyTorch Metric Learning 实现

from pytorch_metric_learning import distances, losses, reducers, testers
from pytorch_metric_learning.utils.accuracy_control import NVRLoss
 
# 使用预置的 InfoNCE loss
class InfoNCELoss(nn.Module):
    def __init__(self, temperature=0.07, use_cosine_similarity=True):
        super().__init__()
        self.temperature = temperature
        self.use_cosine_similarity = use_cosine_similarity
        
        if use_cosine_similarity:
            self.distance = distances.CosineSimilarity()
        else:
            self.distance = distances.LpDistance()
    
    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: (N, d) 表示向量
            labels: (N,) 类别标签(用于确定正负样本)
        """
        reducer = reducers.MeanReducer()
        loss_func = losses.NTXentLoss(
            temperature=self.temperature,
            distance=self.distance,
            reducer=reducer
        )
        return loss_func(embeddings, labels)

InfoNCE 与互信息的关系

理论联系

关键结论:InfoNCE 损失是互信息的下界估计

推导

对于正样本 和负样本

时:

这正是 的下界。

下界性质的意义

特性含义
下界保证最小化 InfoNCE 损失 → 最大化互信息下界 → 最大化真实互信息
样本效率负样本数 越大,下界越紧
温度作用 影响分布的 sharpness

温度参数的作用

的影响

分布特性效果
(e.g., 0.01-0.1)sharp,重点关注最相似的负样本学习更细粒度的区分
中等 (e.g., 0.5-1.0)平滑,平衡正负样本标准设置
(e.g., >1.0)uniform,几乎所有负样本同等重要关注全局结构

温度调度的策略

class CosineAnnealingTemperature:
    """余弦退火温度调度"""
    
    def __init__(self, T_max, T_min=0.01):
        self.T_max = T_max
        self.T_min = T_min
    
    def get_temperature(self, epoch):
        return self.T_min + 0.5 * (self.T_max - self.T_min) * \
               (1 + np.cos(np.pi * epoch / self.T_max))

典型对比学习方法

SimCLR

SimCLR(Simple Contrastive Learning of Visual Representations)是经典的对比学习方法:1

class SimCLR(nn.Module):
    """
    SimCLR: 简化对比学习框架
    
    流程:
    1. 图像 x 经过两种数据增强 -> x_i, x_j
    2. 编码器 f(·) -> 表示 h_i, h_j
    3. 投影头 g(·) -> 表示 z_i, z_j
    4. InfoNCE 损失
    """
    
    def __init__(self, encoder, projection_dim=128):
        super().__init__()
        self.encoder = encoder
        self.projection_head = nn.Sequential(
            nn.Linear(encoder.output_dim, 256),
            nn.ReLU(),
            nn.Linear(256, projection_dim)
        )
    
    def forward(self, x):
        # 两个视图
        x_i = self.augment(x)
        x_j = self.augment(x)
        
        # 编码
        h_i = self.encoder(x_i)
        h_j = self.encoder(x_j)
        
        # 投影
        z_i = self.projection_head(h_i)
        z_j = self.projection_head(h_j)
        
        # InfoNCE 损失
        loss = info_nce_loss(z_i, z_j, temperature=0.07)
        
        return loss, z_i, z_j
    
    def get_representations(self, x):
        """获取表示用于下游任务"""
        return self.encoder(x)
    
    @staticmethod
    def augment(x):
        """数据增强"""
        # SimCLR 使用的增强组合
        transforms = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=23),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                std=[0.229, 0.224, 0.225])
        ]
        return transforms

MoCo

MoCo(Momentum Contrast)使用动量更新的队列维护大量负样本:2

class MoCo(nn.Module):
    """
    MoCo: 动量对比学习
    
    关键创新:
    - 使用动量编码器维持一致的负样本表示
    - 使用队列存储大量负样本
    """
    
    def __init__(self, encoder, projection_dim=128, K=65536, m=0.999):
        super().__init__()
        self.K = K  # 负样本队列大小
        self.m = m  # 动量系数
        
        # 查询编码器
        self.encoder_q = encoder
        self.projection_q = nn.Linear(encoder.output_dim, projection_dim)
        
        # 键编码器(动量更新)
        self.encoder_k = copy.deepcopy(encoder)
        self.projection_k = nn.Linear(encoder.output_dim, projection_dim)
        
        for param_q, param_k in zip(
            self.encoder_q.parameters(), 
            self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        
        # 负样本队列
        self.register_buffer('queue', torch.randn(projection_dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
    
    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """动量更新键编码器"""
        for param_q, param_k in zip(
            self.encoder_q.parameters(),
            self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1 - self.m)
    
    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """更新负样本队列"""
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        
        self.queue[:, ptr:ptr+batch_size] = keys.T
        self.queue_ptr[0] = (ptr + batch_size) % self.K
    
    def forward(self, x_q, x_k):
        # 查询
        q = self.projection_q(self.encoder_q(x_q))
        q = F.normalize(q, dim=1)
        
        # 键(动量编码器,不梯度更新)
        with torch.no_grad():
            k = self.projection_k(self.encoder_k(x_k))
            k = F.normalize(k, dim=1)
        
        # 正样本 logits
        l_pos = torch.einsum('nc,nc->n', q, k).unsqueeze(-1)
        
        # 负样本 logits
        l_neg = torch.einsum('nc,ck->nk', q, self.queue.clone())
        
        # 总 logits
        logits = torch.cat([l_pos, l_neg], dim=1) / 0.07
        
        # 标签(全0,正样本在第一个位置)
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        
        # 交叉熵损失
        loss = F.cross_entropy(logits, labels)
        
        # 更新键编码器和队列
        self._momentum_update_key_encoder()
        self._dequeue_and_enqueue(k)
        
        return loss

BYOL 和 SimSiam

BYOL 和 SimSiam 采用孪生网络架构,不需要负样本:3

class SimSiam(nn.Module):
    """
    SimSiam: 简化的孪生网络
    
    关键创新:
    - 不需要负样本
    - 使用 stop-gradient 避免崩溃解
    - 预测器网络增强表示
    """
    
    def __init__(self, encoder, projection_dim=2048, prediction_dim=512):
        super().__init__()
        
        self.encoder = encoder
        
        # 投影网络
        self.projection = nn.Sequential(
            nn.Linear(encoder.output_dim, projection_dim),
            nn.BatchNorm1d(projection_dim),
            nn.ReLU(),
            nn.Linear(projection_dim, projection_dim)
        )
        
        # 预测网络
        self.predictor = nn.Sequential(
            nn.Linear(projection_dim, prediction_dim),
            nn.BatchNorm1d(prediction_dim),
            nn.ReLU(),
            nn.Linear(prediction_dim, projection_dim)
        )
    
    def forward(self, x1, x2):
        # 编码
        r1 = self.projection(self.encoder(x1))
        r2 = self.projection(self.encoder(x2))
        
        # 预测
        p1 = self.predictor(r1)
        p2 = self.predictor(r2)
        
        # 损失:对称的均方误差
        # 注意:r1, r2 使用 stop-gradient
        loss = 0.5 * (
            F.mse_loss(p1, r2.detach()) + 
            F.mse_loss(p2, r1.detach())
        )
        
        return loss

扩展与改进

Hard Negative Mining

def info_nce_with_hard_negatives(z_i, z_j, temperature=0.07, hard_ratio=0.5):
    """
    InfoNCE with Hard Negative Mining
    
    选择最难的负样本(相似度最高的)进行更强烈的惩罚
    """
    z_i = F.normalize(z_i, dim=1)
    z_j = F.normalize(z_j, dim=1)
    
    N = z_i.shape[0]
    z = torch.cat([z_i, z_j], dim=0)
    
    # 相似度矩阵
    sim = torch.mm(z, z.T) / temperature
    
    # 掩码
    mask = torch.eye(2 * N, dtype=torch.bool, device=sim.device)
    sim.masked_fill_(mask, -float('inf'))
    
    # 选取 hard negatives
    top_k = int(N * hard_ratio)
    hard_sim, _ = sim[:, N:].topk(top_k, dim=1)
    
    # 只对 hard negatives 计算损失
    logits = torch.cat([
        torch.diag(sim[:N, N:]).unsqueeze(1),  # 正样本
        hard_sim  # hard 负样本
    ], dim=1)
    
    labels = torch.zeros(N, dtype=torch.long, device=logits.device)
    
    return F.cross_entropy(logits, labels)

对比损失的正则化

class RegularizedInfoNCE(nn.Module):
    """
    带正则化的 InfoNCE
    
    添加:
    - 方差正则: 避免表示坍塌
    - 均匀性正则: 鼓励表示均匀分布
    - 对齐性正则: 正样本表示应该对齐
    """
    
    def __init__(self, temperature=0.07, lambda_uniform=0.1, lambda_align=0.1):
        super().__init__()
        self.temperature = temperature
        self.lambda_uniform = lambda_uniform
        self.lambda_align = lambda_align
    
    def uniform_loss(self, z):
        """均匀性损失:基于 jensen-Shannon 散度"""
        z = F.normalize(z, dim=1)
        batch_size = z.shape[0]
        
        # 归一化后的点积
        pairwise_sim = torch.mm(z, z.T)
        
        # 均匀性:所有点应该等间距
        # 使用方差作为近似
        variance = torch.var(pairwise_sim)
        return variance
    
    def align_loss(self, z_i, z_j):
        """对齐性损失:正样本应该接近"""
        return F.mse_loss(z_i, z_j)
    
    def forward(self, z_i, z_j):
        # 标准 InfoNCE
        nce = info_nce_loss(z_i, z_j, self.temperature)
        
        # 均匀性正则
        uniform = self.uniform_loss(torch.cat([z_i, z_j], dim=0))
        
        # 对齐性正则
        align = self.align_loss(
            F.normalize(z_i, dim=1), 
            F.normalize(z_j, dim=1)
        )
        
        return nce + self.lambda_uniform * uniform + self.lambda_align * align

核心公式速查

概念公式
InfoNCE
余弦相似度
互信息下界
温度作用 → sharp 分布;大 → uniform 分布

参考

相关文章

Footnotes

  1. Chen, T., Kornblith, S., Norouzi, M., & Hinton, G. (2020). “A Simple Framework for Contrastive Learning of Visual Representations”. ICML.

  2. He, K., Fan, H., Wu, Y., Xie, S., & Girshick, R. (2020). “Momentum Contrast for Unsupervised Visual Representation Learning”. CVPR.

  3. Grill, J.B., et al. (2020). “Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning”. NeurIPS.