信息瓶颈变体与表示学习

信息瓶颈(Information Bottleneck, IB)理论自1999年提出以来,已经发展出一个丰富的变体家族,用于解决不同场景下的表示学习问题。1本文系统梳理各类IB变体的数学形式、核心思想和实践方法,揭示其与表示学习的深层联系。

预备知识

基本设定

考虑随机变量三元组 ,满足 Markov 链 ,即:

  • 完全由 决定(
  • 给定 时, 条件独立

信息平面

为坐标的二维平面:

I(Y;T)
  ↑
  │    · · · · · · IB曲线 · · · · ·
  │   ·                              ·
  │  ·                                ·
  │ ·                                  ·
  │·                                    ·
  │                                      ·
  └──────────────────────────────────────→ I(X;T)

更多信息请参考 信息瓶颈理论


1. 原始信息瓶颈(Original IB)

1.1 目标函数

原始IB的核心思想是找到压缩且任务相关的表示 。形式化为约束优化:

使用拉格朗日乘子法转化为无约束形式:

其中 控制压缩与信息保留之间的权衡:

行为
只关注信息保留, 保留所有 的信息
只关注压缩,完全忽略 的信息

1.2 IB曲线的Pareto最优性

IB曲线(Information Bottleneck Curve)是 Pareto 最优前沿:

Pareto 最优性证明

对于任意两个可行解 ,如果:

则两者都在 Pareto 前沿上,无法同时改进两个目标。

1.3 自洽方程推导

使用变分法求解IB优化问题。引入辅助分布 (类似变分后验),构建拉格朗日函数:

展开互信息:

施加变分,得到 自洽方程(Self-Consistent Equation):

其中 是归一化常数:

直观的解释:对于输入 ,编码器将高概率分配给:

  • 先验 较大的区域
  • 预测分布 接近的区域(KL散度小)

1.4 固定点迭代

自洽方程可通过迭代求解:

def ib_fixed_point_iteration(p_xy, beta, n_iter=100):
    """
    原始IB的固定点迭代算法
    
    Args:
        p_xy: 联合分布 p(x,y)
        beta: 拉格朗日乘子
        n_iter: 迭代次数
    
    Returns:
        p_t_given_x: 编码分布 p(t|x)
    """
    n_x = p_xy.shape[0]
    n_y = p_xy.shape[1]
    n_t = n_x  # 潜在空间大小
    
    # 初始化
    p_t = np.ones(n_t) / n_t
    p_y_given_t = np.ones((n_t, n_y)) / n_y
    
    for _ in range(n_iter):
        # E步:更新 p(t|x)
        p_t_given_x = np.zeros((n_x, n_t))
        for x in range(n_x):
            kl_div = kl_divergence(p_xy[x] / p_xy[x].sum(), p_y_given_t)  # D_KL(p(y|x)||p(y|t))
            unnorm = p_t * np.exp(-beta * kl_div)
            p_t_given_x[x] = unnorm / unnorm.sum()
        
        # M步:更新 p(y|t) 和 p(t)
        p_ty = p_xy.T @ p_t_given_x  # p(y,t)
        p_t = p_ty.sum(axis=0)
        p_y_given_t = p_ty / (p_t[:, np.newaxis] + 1e-10)
        p_t /= p_t.sum()
    
    return p_t_given_x

2. 变分信息瓶颈(Variational IB, VIB)

2.1 变分近似

原始IB的难点在于:

  1. 难以精确计算
  2. 自洽方程难以解析求解
  3. 高维连续空间中需要变分近似

VIB 使用变分推断技术,将IB目标转化为可优化的下界。

2.2 目标函数推导

原始目标(最大化):

变分下界推导

第一步,引入变分分布 近似真实后验

第二步,引入先验分布 近似

综合得到 变分下界(Variational Lower Bound):

在实现中通常简化为:

2.3 与VAE的联系

VIB 与 变分自编码器(VAE) 有着深刻联系:

组件VIBVAE
隐变量后验
先验分布
重建损失
正则项

关键区别

  • VAE 关注重建输入
  • VIB 关注预测标签

数学上,VAE 的 ELBO:

如果令 ,则 VIB 与 VAE 在形式上统一。

2.4 PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import kl_divergence, Normal
 
class VariationalInformationBottleneck(nn.Module):
    """
    变分信息瓶颈(VIB)模块
    
    目标函数:
        L = E[-log q(y|z)] + beta * D_KL(q(z|x) || r(z))
    
    其中:
        - q(z|x): 变分编码器(近似 p(z|x))
        - r(z):   先验分布(通常为 N(0, I))
        - q(y|z): 分类器(近似 p(y|z))
    """
    
    def __init__(self, input_dim, latent_dim, num_classes, beta=1e-3):
        super().__init__()
        self.latent_dim = latent_dim
        self.beta = beta
        self.num_classes = num_classes
        
        # 变分编码器:输出高斯分布参数
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * latent_dim)  # [mean, log_var]
        )
        
        # 先验分布(标准高斯)
        self.register_buffer('prior_mean', torch.zeros(latent_dim))
        self.register_buffer('prior_log_var', torch.zeros(latent_dim))
        
        # 分类器
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )
    
    def reparameterize(self, mean, log_var):
        """重参数化技巧:z = mu + sigma * epsilon"""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mean + eps * std
    
    def kl_divergence(self, mean, log_var):
        """
        计算变分后验与先验的 KL 散度
        
        D_KL(N(mu, sigma) || N(0, I))
        = 0.5 * (sigma^2 + mu^2 - 1 - log(sigma^2))
        """
        prior = Normal(self.prior_mean, torch.ones_like(log_var).exp())
        posterior = Normal(mean, log_var.exp().sqrt())
        
        return kl_divergence(posterior, prior).sum(dim=-1).mean()
    
    def forward(self, x, training=True):
        """
        前向传播
        
        Args:
            x: 输入数据 (batch_size, input_dim)
            training: 是否在训练模式
        
        Returns:
            logits: 分类 logits (batch_size, num_classes)
            mean: 潜在变量均值
            log_var: 潜在变量对数方差
            z: 重参数化后的潜在变量
        """
        # 编码
        h = self.encoder(x)
        mean, log_var = h.chunk(2, dim=-1)
        
        # 重参数化采样
        if training:
            z = self.reparameterize(mean, log_var)
        else:
            z = mean  # 推理时使用均值
        
        # 分类
        logits = self.classifier(z)
        
        return logits, mean, log_var, z
    
    def loss(self, x, y):
        """
        计算 VIB 损失
        
        L = E[-log q(y|z)] + beta * D_KL(q(z|x) || r(z))
        
        Returns:
            total_loss: 总损失
            ce_loss: 交叉熵损失(信息保留项)
            kl_loss: KL 散度(压缩项)
        """
        logits, mean, log_var, z = self.forward(x, training=True)
        
        # 交叉熵损失(最大化 I(Z;Y))
        ce_loss = F.cross_entropy(logits, y, reduction='mean')
        
        # KL 散度正则(最小化 I(Z;X))
        kl_loss = self.kl_divergence(mean, log_var)
        
        # 总损失
        total_loss = ce_loss + self.beta * kl_loss
        
        return total_loss, ce_loss, kl_loss
    
    def estimate_mutual_information(self, x, y):
        """
        估计信息平面坐标
        
        I(Z; Y) >= E[log q(y|z)] + H(Y)
        I(Z; X) <= D_KL(q(z|x) || r(z))
        """
        with torch.no_grad():
            logits, mean, log_var, z = self.forward(x, training=False)
            
            # I(Z;Y) 的下界
            log_probs = F.log_softmax(logits, dim=-1)
            i_zy = torch.gather(log_probs, 1, y.unsqueeze(1)).mean()
            
            # I(Z;X) 的上界
            i_zx = self.kl_divergence(mean, log_var)
        
        return i_zy, i_zx
 
 
class VIBResNet(nn.Module):
    """
    基于 ResNet 的 VIB 模型(可用于图像分类)
    
    将标准 ResNet 的最后层替换为 VIB 模块
    """
    
    def __init__(self, backbone, latent_dim, num_classes, beta=1e-3):
        super().__init__()
        self.backbone = backbone
        self.vib = VariationalInformationBottleneck(
            input_dim=backbone.output_dim,
            latent_dim=latent_dim,
            num_classes=num_classes,
            beta=beta
        )
    
    def forward(self, x, training=True):
        features = self.backbone(x)
        return self.vib(features, training=training)

2.5 使用示例

# 训练循环
def train_vib(model, train_loader, optimizer, device):
    model.train()
    total_loss = 0.0
    total_ce = 0.0
    total_kl = 0.0
    
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        
        optimizer.zero_grad()
        loss, ce_loss, kl_loss = model.loss(batch_x, batch_y)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        total_ce += ce_loss.item()
        total_kl += kl_loss.item()
    
    n_batches = len(train_loader)
    return total_loss / n_batches, total_ce / n_batches, total_kl / n_batches
 
 
# 信息平面可视化
def plot_information_plane(model, data_loader, device):
    """可视化训练过程中的信息平面轨迹"""
    import matplotlib.pyplot as plt
    
    i_zx_list, i_zy_list = [], []
    
    model.eval()
    for batch_x, batch_y in data_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        i_zy, i_zx = model.estimate_mutual_information(batch_x, batch_y)
        i_zx_list.append(i_zx.item())
        i_zy_list.append(i_zy.item())
    
    plt.figure(figsize=(8, 6))
    plt.scatter(i_zx_list, i_zy_list, alpha=0.5)
    plt.xlabel('$I(X; Z)$')
    plt.ylabel('$I(Z; Y)$')
    plt.title('Information Plane')
    plt.grid(True)
    plt.show()

3. 条件信息瓶颈(Conditional IB, CIB)

3.1 问题背景

在许多实际场景中,我们不仅关心 ,还关心 条件 变量 的影响。例如:

  • 领域适应: 表示源/目标域
  • 对抗鲁棒性: 表示对抗扰动
  • 因果推断: 表示混杂因素

3.2 条件互信息的优化

条件信息瓶颈的目标是最大化在给定 条件下 的互信息,同时最小化 的互信息:

其中条件互信息定义为:

I(Y; T \mid C) = \mathbb{E}_{p(c)}\left[I(Y; T \mid C=c)\right] = \mathbb{E}_{p(c)}\left[\mathbb{E}_{p(x,y \mid c)}}\left[\log \frac{p(y \mid t, c)}{p(y \mid c)}\right]\right]

3.3 任务相关表示学习

条件IB的核心思想是学习任务相关的表示:

  • 任务无关信息 中与 无关的部分,应该被压缩
  • 任务相关但域无关 中与 相关但与 无关的部分,应该被保留
  • 虚假相关 中与 都相关的部分(虚假相关),应该被识别和处理
class ConditionalInformationBottleneck(nn.Module):
    """
    条件信息瓶颈
    
    学习在给定条件 C 下对 Y 有用的表示
    """
    
    def __init__(self, input_dim, latent_dim, num_classes, num_domains, beta=1e-3):
        super().__init__()
        self.beta = beta
        
        # 条件编码器:输出依赖输入和条件
        self.encoder = nn.Sequential(
            nn.Linear(input_dim + num_domains, 256),  # 输入拼接条件
            nn.ReLU(),
            nn.Linear(256, 2 * latent_dim)
        )
        
        # 域无关分类器
        self.classifier = nn.Linear(latent_dim, num_classes)
        
        # 域判别器(用于对抗训练)
        self.domain_discriminator = nn.Linear(latent_dim, num_domains)
    
    def forward(self, x, c):
        """
        Args:
            x: 输入特征
            c: 条件/域标签 (one-hot 或 index)
        """
        # 拼接输入和条件
        xc = torch.cat([x, c], dim=-1)
        
        # 编码
        h = self.encoder(xc)
        mean, log_var = h.chunk(2, dim=-1)
        
        # 重参数化
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mean + eps * std
        
        # 分类
        logits = self.classifier(z)
        domain_logits = self.domain_discriminator(z)
        
        return logits, domain_logits, mean, log_var, z
    
    def loss(self, x, y, c, alpha=0.5):
        """
        条件IB损失
        
        L = E[-log q(y|z)] + beta * D_KL(q(z|x,c) || r(z))
            - alpha * E[log q(c|z)]
        
        第三项:对抗损失,鼓励表示与条件解耦
        """
        logits, domain_logits, mean, log_var, z = self.forward(x, c)
        
        # 分类损失
        ce_loss = F.cross_entropy(logits, y)
        
        # KL 正则
        prior = Normal(0, 1)
        posterior = Normal(mean, log_var.exp().sqrt())
        kl_loss = kl_divergence(posterior, prior).sum(dim=-1).mean()
        
        # 对抗损失(最大化域预测误差 => 最小化 I(Z;C))
        domain_loss = F.cross_entropy(domain_logits, c.argmax(dim=-1))
        
        # 总损失
        total_loss = ce_loss + self.beta * kl_loss - alpha * domain_loss
        
        return total_loss, ce_loss, kl_loss, domain_loss

4. 对比信息瓶颈(Contrastive IB, CIB)

4.1 与SimCLR、MoCo的联系

对比学习方法(如 InfoNCE)可以理解为一种特殊的IB变体。考虑对比学习中的数据增强:

对于表示

  • 正样本对 应该相似 最大化
  • 负样本对 应该不同 最小化虚假相关

4.2 正负样本对的信息论分析

假设增强分布 满足:

  • :正样本分布
  • :负样本分布(与 无关)

InfoNCE 损失的IB解释

当温度 时,这近似于互信息的下界:

4.3 的IB解释

从IB视角,InfoNCE 同时优化两个目标:

目标InfoNCE 实现IB 对应
信息保留最大化正样本对的相似度最大化
压缩使用归一化 + temperature限制 的增长

统一的理论框架

其中第二项对应于负样本的对比正则。

4.4 PyTorch实现

class ContrastiveInformationBottleneck(nn.Module):
    """
    对比信息瓶颈
    
    将 InfoNCE 损失解释为 IB 目标的变分近似
    """
    
    def __init__(self, encoder_dim, latent_dim, temperature=0.07, beta=1e-3):
        super().__init__()
        self.temperature = temperature
        self.beta = beta
        
        # 编码器
        self.encoder = encoder
        
        # 投影头
        self.projection = nn.Sequential(
            nn.Linear(encoder_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim)
        )
    
    def info_nce_loss(self, z_i, z_j):
        """
        InfoNCE 损失
        
        L = -log exp(s(z_i, z_j)/tau) / sum_k exp(s(z_i, z_k)/tau)
        """
        batch_size = z_i.shape[0]
        
        # 归一化
        z_i = F.normalize(z_i, dim=1)
        z_j = F.normalize(z_j, dim=1)
        
        # 拼接所有表示
        z = torch.cat([z_i, z_j], dim=0)  # (2N, d)
        
        # 相似度矩阵
        sim = torch.mm(z, z.T) / self.temperature  # (2N, 2N)
        
        # 正样本对:(i, i+N) 和 (i+N, i)
        N = batch_size
        sim_i_pos = torch.diag(sim, N)      # z_i 与对应 z_j
        sim_j_pos = torch.diag(sim, -N)     # z_j 与对应 z_i
        pos = torch.cat([sim_i_pos, sim_j_pos], dim=0)
        
        # 掩码自身
        mask = torch.eye(2 * N, dtype=torch.bool, device=sim.device)
        sim.masked_fill_(mask, -float('inf'))
        
        # 负样本
        neg = sim.logsumexp(dim=1)
        
        # InfoNCE 损失
        loss = -torch.mean(pos - neg)
        
        return loss
    
    def ib_regularization(self, z):
        """
        IB 正则:鼓励表示压缩
        
        使用方差正则近似信息压缩
        """
        # 协方差正则
        z = z - z.mean(dim=0)
        cov = (z.T @ z) / (z.shape[0] - 1)
        off_diag = cov.fill_diagonal_(0)
        
        # 鼓励对角协方差(维度解耦)
        diag_loss = off_diag.abs().mean()
        
        # 鼓励均匀分布(信息最大化)
        uniformity = torch.pdist(z, p=2).pow(2).mean()
        
        return diag_loss + uniformity
    
    def forward(self, x_i, x_j):
        """
        Args:
            x_i, x_j: 同一batch的两个增强视图
        """
        # 编码
        h_i = self.encoder(x_i)
        h_j = self.encoder(x_j)
        
        # 投影
        z_i = self.projection(h_i)
        z_j = self.projection(h_j)
        
        # InfoNCE 损失(信息保留)
        nce_loss = self.info_nce_loss(z_i, z_j)
        
        # IB 正则(压缩)
        all_z = torch.cat([z_i, z_j], dim=0)
        ib_loss = self.ib_regularization(all_z)
        
        # 总损失
        total_loss = nce_loss + self.beta * ib_loss
        
        return total_loss, nce_loss, ib_loss
 
 
# 与 MoCo 结合
class ContrastiveIBMoCo(nn.Module):
    """
    对比IB与MoCo的结合
    
    使用MoCo维护负样本队列,同时加入IB正则
    """
    
    def __init__(self, encoder, projection_dim=128, K=65536, m=0.999, beta=1e-3):
        super().__init__()
        self.K = K
        self.m = m
        self.beta = beta
        
        # 查询编码器
        self.encoder_q = encoder
        self.projection_q = nn.Linear(encoder.output_dim, projection_dim)
        
        # 键编码器(动量更新)
        self.encoder_k = copy.deepcopy(encoder)
        self.projection_k = nn.Linear(encoder.output_dim, projection_dim)
        
        # 负样本队列
        self.register_buffer('queue', torch.randn(projection_dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
    
    @torch.no_grad()
    def _momentum_update_key(self):
        """动量更新键编码器"""
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data.mul_(self.m).add_(param_q.data, alpha=1 - self.m)
        
        for param_q, param_k in zip(
            self.projection_q.parameters(), self.projection_k.parameters()
        ):
            param_k.data.mul_(self.m).add_(param_q.data, alpha=1 - self.m)
    
    def forward(self, x_q, x_k):
        # 查询表示
        q = F.normalize(self.projection_q(self.encoder_q(x_q)), dim=1)
        
        # 键表示(动量编码器)
        with torch.no_grad():
            k = F.normalize(self.projection_k(self.encoder_k(x_k)), dim=1)
        
        # 正样本 logits
        l_pos = torch.einsum('nc,nc->n', q, k).unsqueeze(1)
        
        # 负样本 logits
        l_neg = torch.einsum('nc,ck->nk', q, self.queue.clone())
        
        # 总 logits
        logits = torch.cat([l_pos, l_neg], dim=1) / 0.07
        
        # 标签
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        
        # 对比损失
        contrastive_loss = F.cross_entropy(logits, labels)
        
        # IB 正则:鼓励压缩(限制队列表示的熵)
        queue_entropy = -(self.queue * torch.log(self.queue + 1e-10)).sum(dim=0).mean()
        
        # 更新
        self._momentum_update_key()
        self._dequeue_and_enqueue(k)
        
        return contrastive_loss + self.beta * queue_entropy

5. 其他变体

5.1 任务导向IB(Task-Oriented IB, TOIB)

核心思想:不同任务需要不同级别的压缩。

传统IB对所有任务使用统一的压缩级别;TOIB根据任务特性自适应调整:

class TaskOrientedIB(nn.Module):
    """
    任务导向信息瓶颈
    
    为不同任务学习不同程度压缩的表示
    """
    
    def __init__(self, input_dim, latent_dim, num_tasks):
        super().__init__()
        self.num_tasks = num_tasks
        
        # 共享编码器
        self.shared_encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256)
        )
        
        # 任务特定编码器(调整压缩级别)
        self.task_encoders = nn.ModuleList([
            nn.Linear(256, latent_dim * task_beta)  # 不同任务不同潜在维度
            for task_beta in [1.0, 0.75, 0.5, 0.25]  # 压缩级别
        ])
        
        # 任务分类器
        self.task_classifiers = nn.ModuleList([
            nn.Linear(latent_dim, num_classes)
            for _ in range(num_tasks)
        ])
    
    def forward(self, x, task_id):
        # 共享特征
        h = self.shared_encoder(x)
        
        # 任务特定编码
        z = self.task_encoders[task_id](h)
        
        # 分类
        logits = self.task_classifiers[task_id](z)
        
        return logits, z

5.2 归一化IB(Normalized IB, NIB)

问题:原始IB的权衡参数 对不同数据集和模型架构敏感。

解决方案:使用归一化的互信息:

归一化后的IB具有更好的跨设置泛化能力。

5.3 不确定性感知IB(Uncertainty-Aware IB, UAIB)

核心思想:区分认知不确定性( epistemic uncertainty)和偶然不确定性( aleatoric uncertainty)。

  • 认知不确定性:由训练数据不足导致,可通过更多数据减少
  • 偶然不确定性:数据本身的固有噪声,不可减少
class UncertaintyAwareIB(nn.Module):
    """
    不确定性感知信息瓶颈
    
    分解表示中的认知和偶然不确定性
    """
    
    def __init__(self, input_dim, latent_dim, num_classes):
        super().__init__()
        
        # 共享编码器
        self.encoder = nn.Linear(input_dim, latent_dim)
        
        # 认知不确定性估计器(数据依赖)
        self.epistemic_head = nn.Linear(latent_dim, latent_dim)
        
        # 偶然不确定性估计器(输入依赖)
        self.aleatoric_head = nn.Sequential(
            nn.Linear(input_dim, latent_dim),
            nn.Linear(latent_dim, 1)  # 输出 log(sigma^2)
        )
        
        # 分类器
        self.classifier = nn.Linear(latent_dim, num_classes)
    
    def forward(self, x):
        z = self.encoder(x)
        
        # 认知不确定性
        epistemic_var = torch.exp(self.epistemic_head(z))
        
        # 偶然不确定性
        aleatoric_var = torch.exp(self.aleatoric_head(x))
        
        # 总不确定性
        total_var = epistemic_var + aleatoric_var
        
        return z, epistemic_var, aleatoric_var, total_var
    
    def loss(self, x, y, beta=1e-3):
        z, epi_var, alea_var, total_var = self.forward(x)
        
        # 带不确定性的分类损失
        logits = self.classifier(z)
        
        # NLL 损失(隐式处理不确定性)
        nll_loss = F.cross_entropy(logits, y, reduction='none')
        nll_loss = (nll_loss / (alea_var + 1e-6)).mean()
        
        # 不确定性正则
        # 鼓励高认知不确定性(表示不确信)但低偶然不确定性
        uncertainty_loss = epi_var.mean() - alea_var.mean()
        
        total_loss = nll_loss + beta * uncertainty_loss
        
        return total_loss, nll_loss, uncertainty_loss

5.4 变体对比总结

变体目标函数主要应用
原始IB理论基础
VIB$\min \mathbb{E}[-\log q(yz)] + \beta D_{KL}$
CIB领域适应
CIB (Contrastive)自监督学习
TOIB多任务学习
NIB跨设置泛化
UAIB$\min \text{NLL} + \beta \cdot (\text{epi}

6. 统一框架

6.1 IB变体的层次结构

                    ┌─────────────────┐
                    │   Information   │
                    │   Bottleneck    │
                    └────────┬────────┘
                             │
          ┌──────────────────┼──────────────────┐
          │                  │                  │
          ▼                  ▼                  ▼
    ┌──────────┐      ┌──────────┐      ┌──────────┐
    │ Variational│    │Conditional│     │Contrastive│
    │    IB    │      │    IB    │      │    IB    │
    └──────────┘      └──────────┘      └──────────┘
          │                  │                  │
          ▼                  ▼                  ▼
    ┌──────────┐      ┌──────────┐      ┌──────────┐
    │   VAE    │      │DANN/ADDA │      │ SimCLR   │
    └──────────┘      └──────────┘      └──────────┘

6.2 统一数学形式

所有IB变体可以统一为:

其中 是辅助正则项:

  • :原始IB
  • :条件IB(解耦)
  • :对比IB

核心公式速查

变体公式
原始IB
VIB损失$\mathbb{E}[-\log q(y\mid z)] + \beta D_{KL}(q(z\mid x)\
条件IB
对比IB
自洽方程$p(t\mid x) \propto p(t) \exp(-\beta D_{KL}(p(y\mid x)\
InfoNCE下界

参考

相关文章


Footnotes

  1. Tishby, N., Pereira, F.C., & Bialek, W. (1999). “The Information Bottleneck Method”. Proceedings of the 37th Annual Allerton Conference on Communication, Control, and Computing.