自监督学习与测试时适应的结合

概述

测试时适应(Test-Time Adaptation, TTA)使深度学习模型能够在推理阶段适应动态环境变化,提升实际部署中的适用性。然而,现有TTA方法高度依赖于源域预训练模型的性能,这带来了一个被忽视的计算成本问题——源域预训练往往需要大量时间和计算资源。

本文基于arXiv:2506.23529的研究,系统探讨自监督测试时适应协议(Self-Supervised TTA Protocol),分析现有TTA方法在SSL模型上的失效原因,并提出一种无需源域预训练的协作学习框架AWS(Adapt With-out Source pretraining)。1


1. 问题背景与动机

1.1 传统TTA的依赖性

传统TTA方法假设存在一个在源域上通过监督学习预训练的模型 ,使用标记数据 进行训练。这种范式存在两个核心问题:

  1. 计算成本高昂:源域预训练时间可能远超测试时适应过程本身
  2. 跨域泛化受限:针对特定源域训练的模型难以直接泛化到其他目标域
预训练任务ImageNet预训练CIFAR100预训练
训练时间1小时8分23秒 × 300 epochs9分7秒 × 200 epochs
SSL + 原型方法36分25秒1分25秒
SSL + 原型(少样本)1分56秒7秒

表1:源域预训练与传统TTA方法的计算成本对比1

1.2 SSL模型的潜力

自监督学习(SSL)模型如DINO、MoCo、iBOT等,通过大规模无标注数据进行预训练,展现出强大的零样本泛化能力。这些模型可直接作为TTA的初始化,有望解决传统方法面临的计算效率问题。

然而,研究发现现有TTA方法在SSL模型上的表现并不理想

模型类型基础准确率TENT提升CoTTA提升AWS提升
监督学习模型83.6%+4.8%+1.0%+16.4%
DINO63.1%+2.6%-1.9%+16.2%
MoCo60.0%+1.1%-1.7%+7.0%
iBOT65.9%+0.9%0.0%+19.9%

表2:不同模型在ImageNetC上的TTA性能对比1


2. 现有TTA方法在SSL模型上的失效分析

2.1 熵最小化方法的失效

熵最小化(Entropy Minimization, EM) 是TTA的主流方法,其核心假设是:低熵预测往往对应高准确率。然而,在SSL模型上这一假设不再成立。

SSL模型表现出以下特点:

  1. 相同熵值下损失更高:SSL模型在同一熵水平上往往具有更高的损失
  2. 错误预测熵不减:SSL模型可能降低错误预测的熵,从而增加真实风险
  3. 缺乏判别性特征:无监督预训练没有针对分类任务优化特征表示

式(1):基于原型的分类器概率计算公式1

2.2 一致性正则化方法的失效

一致性正则化(Consistency Regularization, CR) 方法依赖伪标签进行知识迁移。然而,SSL模型生成的伪标签准确率较低,导致误差传播:

模型类型目标域伪标签准确率问题
监督学习模型~60-70%相对可控
SSL模型~30-50%严重误差传播

表3:不同模型的伪标签准确率对比1


3. 自监督测试时适应协议

3.1 协议定义

自监督测试时适应协议(Self-Supervised TTA Protocol)是本文提出的新型适应范式,其核心特点:

组件传统TTA自监督TTA
预训练数据(有标签)(无标签)
训练过程源域监督训练SSL无监督预训练
源域访问需要完整访问仅需前向传播
计算成本

表4:传统TTA与自监督TTA协议对比1

3.2 原型分类器

SSL模型缺乏针对各类的分类器,这是将SSL应用于TTA的主要挑战。原型分类器通过以下方式解决此问题:

计算类原型

其中 是类别 的特征均值,作为该类的原型表示。

分类决策:基于余弦相似度的原型分类器,如式(1)所示。

优势

  • 仅需前向传播,无需梯度回传
  • 计算效率高
  • 可通过少样本快速构建

4. AWS协作学习框架

4.1 框架概述

AWS(Adapt With-out Source pretraining)框架通过三阶段协作学习整合SSL模型与TTA模型的优势:

┌─────────────────────────────────────────────────────────────┐
│                    AWS 协作学习框架                           │
├─────────────────────────────────────────────────────────────┤
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐      │
│  │  对比学习   │ →  │  知识蒸馏   │ →  │  互学习     │      │
│  │     CL     │    │     KD      │    │     ML      │      │
│  └─────────────┘    └─────────────┘    └─────────────┘      │
│       ↓                   ↓                   ↓             │
│  表示细化          泛化保持          协同适应                │
└─────────────────────────────────────────────────────────────┘

4.2 对比学习(Contrastive Learning)

核心思想:通过伪标签引导的对比损失,逐步细化目标域的表示。

近似正确对比学习

  1. 正样本定义:若样本 的 top- 预测存在交集,则视为正样本对
  2. 负样本定义:若两样本的 top-)预测无交集,则视为负样本对
  3. 模糊样本:不应用对比损失

指示函数定义

对比损失函数

其中 是特征间的余弦相似度, 是批次大小。

关键优势

  • 利用SSL模型的初始分类能力
  • 保持适应过程的稳定性
  • 避免低质量伪标签的误差传播

4.3 知识蒸馏(Knowledge Distillation)

核心思想:将SSL模型的知识迁移到目标模型,保持泛化能力并缓解连续域偏移下的过拟合。

知识蒸馏损失

其中 是归一化特征向量, 是Frobenius范数。

设计原理

  1. 通过对齐归一化特征向量,确保稳定的知识迁移
  2. 保留SSL模型嵌入空间的几何结构
  3. 减少特征表示的旋转,保持预测一致性

4.4 互学习(Mutual Learning)

核心思想:SSL模型提供泛化能力,目标模型提供领域特定知识,通过协同学习整合两种优势。

互学习损失

其中:

  • :目标模型的概率分布
  • :目标模型生成的伪标签
  • :交叉熵
  • :互信息1

互信息的作用:捕获样本之间的关系信息,利用SSL模型的表示能力。

4.5 总体损失函数

AWS的总体损失函数为三个组件的加权和:

其中 是平衡各组件的超参数。

典型参数设置

  • (ImageNetC和CIFAR10C)或 (CIFAR100C)

5. 实验结果与分析

5.1 实验设置

数据集

  • ImageNet → ImageNetC
  • CIFAR10 → CIFAR10C
  • CIFAR100 → CIFAR100C

SSL模型

  • DINO(ViT-B/16)
  • MoCo v3(ViT-B/16)
  • iBOT(ViT-B/16)

评估协议:遵循持续测试时适应(CTTA)协议,在15个目标域上顺序适应。

5.2 主要结果

ImageNet → ImageNetC

方法平均错误率相对基线提升
No Adapt(基线)69.2%0.0%
TENT67.3%+1.9%
CoTTA71.1%-1.9%
AWS(DINO)53.0%+16.2%
AWS(iBOT)48.1%+19.9%

表5:ImageNetC上的SSL模型TTA性能1

CIFAR基准

模型CIFAR10CCIFAR100C
DINO基线44.3%64.1%
AWS(DINO)26.8%50.6%
MoCo基线42.2%64.2%
AWS(MoCo)40.7%62.1%
iBOT基线48.0%65.6%
AWS(iBOT)30.1%50.2%

表6:CIFAR数据集上的平均错误率1

5.3 消融实验

对比学习知识蒸馏互学习ImageNetCCIFAR10CCIFAR100C
55.8%28.2%35.4%
43.4%16.0%27.9%
40.6%11.2%22.2%
39.4%10.8%20.4%

表7:各组件的贡献分析1

关键发现

  1. 对比学习单独使用即可带来显著提升(+12.4% on ImageNetC)
  2. 知识蒸馏有效缓解灾难性遗忘
  3. 互学习进一步提升性能,达到最佳组合效果

5.4 特征可视化分析

通过t-SNE可视化分析,发现:

  • 传统方法:保守更新策略(如仅更新归一化层)保持初始表示,但受限于源模型初始性能
  • AWS方法:展现更清晰的决策边界,对SSL模型和监督学习模型均有效

6. PyTorch实现

以下是AWS框架的完整PyTorch实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
 
class AWSLoss(nn.Module):
    """
    Adapt With-out Source pretraining (AWS) Loss
    
    整合对比学习、知识蒸馏和互学习的协作学习框架
    """
    
    def __init__(
        self,
        num_classes: int,
        k: int = 1,
        n: int = 5,
        lambda_kd: float = 0.01,
        lambda_ml: float = 0.4,
        temperature: float = 0.1
    ):
        super().__init__()
        self.k = k  # 正样本阈值
        self.n = n  # 负样本阈值
        self.lambda_kd = lambda_kd
        self.lambda_ml = lambda_ml
        self.temperature = temperature
    
    def get_topk_predictions(
        self, 
        logits: torch.Tensor, 
        k: int
    ) -> torch.Tensor:
        """获取top-k预测的类别索引"""
        return torch.topk(logits, k=k, dim=-1).indices
    
    def compute_indicator(
        self, 
        topk_i: torch.Tensor, 
        topk_j: torch.Tensor,
        topn_i: torch.Tensor,
        topn_j: torch.Tensor
    ) -> torch.Tensor:
        """
        计算样本对的关系指示函数
        
        Returns:
            indicator: 正样本=1, 负样本=-1, 模糊样本=0
        """
        batch_size = topk_i.size(0)
        
        # 正样本:top-k预测有交集
        # [batch_size, k] -> [batch_size, batch_size, k]
        topk_i_expand = topk_i.unsqueeze(1).expand(-1, batch_size, -1)
        topk_j_expand = topk_j.unsqueeze(0).expand(batch_size, -1, -1)
        positive = (topk_i_expand == topk_j_expand).any(dim=-1).float()  # [B, B]
        
        # 负样本:top-n预测无交集
        topn_i_expand = topn_i.unsqueeze(1).expand(-1, batch_size, -1)
        topn_j_expand = topn_j.unsqueeze(0).expand(batch_size, -1, -1)
        negative = (topn_i_expand == topn_j_expand).any(dim=-1).float()  # [B, B]
        
        # 综合指示函数
        indicator = positive - negative  # 正样本1,负样本-1,模糊0
        return indicator
    
    def contrastive_loss(
        self,
        features: torch.Tensor,
        logits: torch.Tensor,
        indicator: torch.Tensor
    ) -> torch.Tensor:
        """
        对比学习损失
        
        Args:
            features: [B, D] 归一化后的特征
            logits: [B, C] 分类logits
        """
        # 计算余弦相似度矩阵
        similarity = torch.matmul(features, features.T) / self.temperature
        
        # 计算正样本和负样本的指示掩码
        mask_positive = (indicator == 1).float()
        mask_negative = (indicator == -1).float()
        mask_ignore = (indicator == 0).float()
        
        # InfoNCE损失
        exp_sim = torch.exp(similarity)
        log_prob = similarity - torch.log(exp_sim.sum(dim=-1, keepdim=True))
        
        # 仅对正样本对计算损失
        loss = -(mask_positive * log_prob).sum() / (mask_positive.sum() + 1e-8)
        
        return loss
    
    def knowledge_distillation_loss(
        self,
        target_features: torch.Tensor,
        ssl_features: torch.Tensor
    ) -> torch.Tensor:
        """
        知识蒸馏损失:对齐SSL模型和目标模型的特征表示
        
        Args:
            target_features: [B, D] 目标模型特征(归一化)
            ssl_features: [B, D] SSL模型特征(归一化)
        """
        return F.mse_loss(target_features, ssl_features)
    
    def mutual_learning_loss(
        self,
        ssl_probs: torch.Tensor,
        target_probs: torch.Tensor,
        target_logits: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        互学习损失
        
        Args:
            ssl_probs: [B, C] SSL模型概率分布
            target_probs: [B, C] 目标模型概率分布
            target_logits: [B, C] 目标模型logits
        
        Returns:
            ssl_loss: 更新SSL模型的损失
            target_loss: 更新目标模型的损失
        """
        # SSL模型损失:使用目标模型的伪标签
        pseudo_labels = target_logits.argmax(dim=-1)
        ssl_loss = F.cross_entropy(
            torch.log(ssl_probs + 1e-8), 
            pseudo_labels
        )
        
        # 目标模型损失:互信息(简化版本)
        # I(p_t, p_ssl) ≈ -KL(p_t || p_ssl)
        target_loss = F.kl_div(
            torch.log(target_probs + 1e-8),
            ssl_probs,
            reduction='batchmean'
        )
        
        return ssl_loss, target_loss
    
    def forward(
        self,
        target_features: torch.Tensor,
        ssl_features: torch.Tensor,
        target_logits: torch.Tensor,
        ssl_logits: torch.Tensor
    ) -> Tuple[torch.Tensor, dict]:
        """
        AWS总损失计算
        
        Args:
            target_features: [B, D] 目标模型特征
            ssl_features: [B, D] SSL模型特征
            target_logits: [B, C] 目标模型logits
            ssl_logits: [B, C] SSL模型logits
        
        Returns:
            total_loss: 总损失
            loss_dict: 各组件损失
        """
        batch_size = target_features.size(0)
        
        # 归一化特征
        target_features_norm = F.normalize(target_features, p=2, dim=-1)
        ssl_features_norm = F.normalize(ssl_features, p=2, dim=-1)
        
        # 获取top-k和top-n预测
        target_topk = self.get_topk_predictions(target_logits, self.k)
        target_topn = self.get_topk_predictions(target_logits, self.n)
        
        # 计算指示函数
        indicator = self.compute_indicator(
            target_topk, target_topk,
            target_topn, target_topn
        )
        
        # 对比学习损失
        loss_cl = self.contrastive_loss(
            target_features_norm, 
            target_logits, 
            indicator
        )
        
        # 知识蒸馏损失
        loss_kd = self.knowledge_distillation_loss(
            target_features_norm,
            ssl_features_norm
        )
        
        # 互学习损失
        target_probs = F.softmax(target_logits, dim=-1)
        ssl_probs = F.softmax(ssl_logits, dim=-1)
        loss_ssl, loss_target = self.mutual_learning_loss(
            ssl_probs, target_probs, target_logits
        )
        loss_ml = loss_ssl + loss_target
        
        # 总损失
        total_loss = loss_cl + self.lambda_kd * loss_kd + self.lambda_ml * loss_ml
        
        loss_dict = {
            'loss_total': total_loss,
            'loss_cl': loss_cl,
            'loss_kd': loss_kd,
            'loss_ml': loss_ml,
            'loss_ssl': loss_ssl,
            'loss_target': loss_target
        }
        
        return total_loss, loss_dict
 
 
class PrototypeClassifier(nn.Module):
    """
    原型分类器:无需反向传播的分类器
    
    通过计算类原型并使用余弦相似度进行分类
    """
    
    def __init__(self, feat_dim: int, num_classes: int, logit_scale: float = 10.0):
        super().__init__()
        self.feat_dim = feat_dim
        self.num_classes = num_classes
        self.logit_scale = logit_scale
        
        # 可学习的类原型(用于在线更新)
        self.register_buffer('prototypes', torch.zeros(num_classes, feat_dim))
        self.register_buffer('class_counts', torch.zeros(num_classes))
    
    def update_prototypes(
        self, 
        features: torch.Tensor, 
        labels: torch.Tensor,
        momentum: float = 0.0
    ):
        """
        更新类原型
        
        Args:
            features: [B, D] 特征向量
            labels: [B] 类别标签
            momentum: 动量更新系数(0表示完全替换)
        """
        for feat, label in zip(features, labels):
            label = label.item()
            if momentum > 0:
                self.prototypes[label] = (
                    momentum * self.prototypes[label] + 
                    (1 - momentum) * feat.detach()
                )
            else:
                self.prototypes[label] = feat.detach()
            self.class_counts[label] += 1
    
    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        前向传播:计算分类logits
        
        Args:
            features: [B, D] 输入特征
        
        Returns:
            logits: [B, C] 分类logits
        """
        # 归一化特征和原型
        features_norm = F.normalize(features, p=2, dim=-1)
        prototypes_norm = F.normalize(self.prototypes, p=2, dim=-1)
        
        # 余弦相似度
        similarity = torch.matmul(features_norm, prototypes_norm.T)
        
        # 缩放
        logits = similarity * self.logit_scale
        
        return logits
 
 
class AWSTestTimeAdaptation:
    """
    AWS测试时适应主类
    
    整合SSL模型与目标模型的协作学习
    """
    
    def __init__(
        self,
        ssl_model: nn.Module,
        target_model: nn.Module,
        num_classes: int,
        device: torch.device,
        k: int = 1,
        n: int = 5,
        lambda_kd: float = 0.01,
        lambda_ml: float = 0.4,
        lr: float = 1e-4,
        momentum: float = 0.9
    ):
        self.device = device
        self.ssl_model = ssl_model.to(device)
        self.target_model = target_model.to(device)
        
        # 损失函数
        self.criterion = AWSLoss(
            num_classes=num_classes,
            k=k,
            n=n,
            lambda_kd=lambda_kd,
            lambda_ml=lambda_ml
        )
        
        # 原型分类器
        self.prototype_classifier = PrototypeClassifier(
            feat_dim=ssl_model.embed_dim,
            num_classes=num_classes
        ).to(device)
        
        # 优化器
        self.optimizer = torch.optim.SGD(
            target_model.parameters(),
            lr=lr,
            momentum=momentum
        )
    
    @torch.no_grad()
    def compute_ssl_features(self, images: torch.Tensor) -> torch.Tensor:
        """计算SSL模型特征"""
        self.ssl_model.eval()
        return self.ssl_model(images)
    
    def adapt_batch(self, images: torch.Tensor) -> dict:
        """
        对单个批次进行适应
        
        Args:
            images: [B, C, H, W] 输入图像
        
        Returns:
            results: 包含loss和各组件损失的字典
        """
        # 前向传播
        ssl_features = self.compute_ssl_features(images)
        target_features = self.target_model(images)
        
        # 分类
        ssl_logits = self.prototype_classifier(ssl_features)
        target_logits = self.prototype_classifier(target_features)
        
        # 计算损失
        total_loss, loss_dict = self.criterion(
            target_features=target_features,
            ssl_features=ssl_features,
            target_logits=target_logits,
            ssl_logits=ssl_logits
        )
        
        # 反向传播
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        
        return loss_dict
    
    def evaluate(self, images: torch.Tensor) -> torch.Tensor:
        """
        评估模式:仅使用目标模型
        
        Args:
            images: [B, C, H, W] 输入图像
        
        Returns:
            predictions: 预测类别
        """
        self.target_model.eval()
        with torch.no_grad():
            target_features = self.target_model(images)
            logits = self.prototype_classifier(target_features)
            predictions = logits.argmax(dim=-1)
        return predictions
 
 
# 使用示例
def demo_usage():
    """AWS框架使用示例"""
    import torchvision.models as models
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 假设SSL模型已预训练
    ssl_backbone = models.vit_b_16(pretrained=False)
    target_backbone = models.vit_b_16(pretrained=False)
    
    # 创建TTA系统
    aws_tta = AWSTestTimeAdaptation(
        ssl_model=ssl_backbone,
        target_model=target_backbone,
        num_classes=1000,
        device=device,
        k=1,
        n=5,
        lambda_kd=0.01,
        lambda_ml=0.4,
        lr=1e-4 * 64 / 64,  # 随batch size缩放
        momentum=0.9
    )
    
    # 初始化原型分类器(使用少量样本)
    # aws_tta.prototype_classifier.update_prototypes(features, labels)
    
    # 测试适应过程
    dummy_images = torch.randn(16, 3, 224, 224).to(device)
    loss_dict = aws_tta.adapt_batch(dummy_images)
    
    print("适应损失:")
    for key, value in loss_dict.items():
        print(f"  {key}: {value.item():.4f}")
 
 
if __name__ == "__main__":
    demo_usage()

7. 与相关工作的联系

7.1 与对比学习的关系

AWS框架中的对比学习组件与对比学习理论密切相关。区别在于:

  • 传统对比学习:在预训练阶段学习通用表示
  • AWS对比学习:在测试时细化领域特定表示,利用伪标签关系

7.2 与知识蒸馏的关系

知识蒸馏组件继承自知识蒸馏基础的核心思想,但应用于:

  • 传统知识蒸馏:从大模型到小模型的知识迁移
  • AWS知识蒸馏:从SSL模型到目标模型的知识保持

7.3 与测试时计算扩展的关系

AWS方法与测试时计算扩展有着相似的哲学——在推理阶段投入额外计算以提升性能:

维度测试时计算扩展AWS
计算类型推理时的多次前向传播在线学习更新
目标提升推理质量适应分布偏移
应用场景LLM推理视觉模型适应

8. 结论与展望

8.1 主要贡献

  1. 首次揭示了传统TTA方法在SSL模型上的失效问题
  2. 提出自监督测试时适应协议,显著降低计算成本
  3. 设计AWS协作学习框架,整合对比学习、知识蒸馏和互学习
  4. 验证了方法在多种SSL模型(DINO、MoCo、iBOT)上的有效性

8.2 局限性

  1. 仍需要一定量的目标域数据流
  2. 超参数对不同数据集可能需要调整
  3. 对极端分布偏移的鲁棒性有待进一步验证

8.3 未来方向

  1. 探索更高效的原型更新策略
  2. 将AWS扩展到更多SSL架构(如MAE、Jepa)
  3. 结合测试时计算扩展的混合方法

参考

Footnotes

  1. Jisu Han, Jihee Park, Dongyoon Han, Wonjun Hwang. “When Test-Time Adaptation Meets Self-Supervised Models.” arXiv:2506.23529, 2025. 2 3 4 5 6 7 8 9 10