自监督学习与测试时适应的结合
概述
测试时适应(Test-Time Adaptation, TTA)使深度学习模型能够在推理阶段适应动态环境变化,提升实际部署中的适用性。然而,现有TTA方法高度依赖于源域预训练模型的性能,这带来了一个被忽视的计算成本问题——源域预训练往往需要大量时间和计算资源。
本文基于arXiv:2506.23529的研究,系统探讨自监督测试时适应协议(Self-Supervised TTA Protocol),分析现有TTA方法在SSL模型上的失效原因,并提出一种无需源域预训练的协作学习框架AWS(Adapt With-out Source pretraining)。1
1. 问题背景与动机
1.1 传统TTA的依赖性
传统TTA方法假设存在一个在源域上通过监督学习预训练的模型 ,使用标记数据 进行训练。这种范式存在两个核心问题:
- 计算成本高昂:源域预训练时间可能远超测试时适应过程本身
- 跨域泛化受限:针对特定源域训练的模型难以直接泛化到其他目标域
| 预训练任务 | ImageNet预训练 | CIFAR100预训练 |
|---|---|---|
| 训练时间 | 1小时8分23秒 × 300 epochs | 9分7秒 × 200 epochs |
| SSL + 原型方法 | 36分25秒 | 1分25秒 |
| SSL + 原型(少样本) | 1分56秒 | 7秒 |
表1:源域预训练与传统TTA方法的计算成本对比1
1.2 SSL模型的潜力
自监督学习(SSL)模型如DINO、MoCo、iBOT等,通过大规模无标注数据进行预训练,展现出强大的零样本泛化能力。这些模型可直接作为TTA的初始化,有望解决传统方法面临的计算效率问题。
然而,研究发现现有TTA方法在SSL模型上的表现并不理想:
| 模型类型 | 基础准确率 | TENT提升 | CoTTA提升 | AWS提升 |
|---|---|---|---|---|
| 监督学习模型 | 83.6% | +4.8% | +1.0% | +16.4% |
| DINO | 63.1% | +2.6% | -1.9% | +16.2% |
| MoCo | 60.0% | +1.1% | -1.7% | +7.0% |
| iBOT | 65.9% | +0.9% | 0.0% | +19.9% |
表2:不同模型在ImageNetC上的TTA性能对比1
2. 现有TTA方法在SSL模型上的失效分析
2.1 熵最小化方法的失效
熵最小化(Entropy Minimization, EM) 是TTA的主流方法,其核心假设是:低熵预测往往对应高准确率。然而,在SSL模型上这一假设不再成立。
SSL模型表现出以下特点:
- 相同熵值下损失更高:SSL模型在同一熵水平上往往具有更高的损失
- 错误预测熵不减:SSL模型可能降低错误预测的熵,从而增加真实风险
- 缺乏判别性特征:无监督预训练没有针对分类任务优化特征表示
式(1):基于原型的分类器概率计算公式1
2.2 一致性正则化方法的失效
一致性正则化(Consistency Regularization, CR) 方法依赖伪标签进行知识迁移。然而,SSL模型生成的伪标签准确率较低,导致误差传播:
| 模型类型 | 目标域伪标签准确率 | 问题 |
|---|---|---|
| 监督学习模型 | ~60-70% | 相对可控 |
| SSL模型 | ~30-50% | 严重误差传播 |
表3:不同模型的伪标签准确率对比1
3. 自监督测试时适应协议
3.1 协议定义
自监督测试时适应协议(Self-Supervised TTA Protocol)是本文提出的新型适应范式,其核心特点:
| 组件 | 传统TTA | 自监督TTA |
|---|---|---|
| 预训练数据 | (有标签) | (无标签) |
| 训练过程 | 源域监督训练 | SSL无监督预训练 |
| 源域访问 | 需要完整访问 | 仅需前向传播 |
| 计算成本 | 高 | 低 |
表4:传统TTA与自监督TTA协议对比1
3.2 原型分类器
SSL模型缺乏针对各类的分类器,这是将SSL应用于TTA的主要挑战。原型分类器通过以下方式解决此问题:
计算类原型:
其中 是类别 的特征均值,作为该类的原型表示。
分类决策:基于余弦相似度的原型分类器,如式(1)所示。
优势:
- 仅需前向传播,无需梯度回传
- 计算效率高
- 可通过少样本快速构建
4. AWS协作学习框架
4.1 框架概述
AWS(Adapt With-out Source pretraining)框架通过三阶段协作学习整合SSL模型与TTA模型的优势:
┌─────────────────────────────────────────────────────────────┐
│ AWS 协作学习框架 │
├─────────────────────────────────────────────────────────────┤
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ 对比学习 │ → │ 知识蒸馏 │ → │ 互学习 │ │
│ │ CL │ │ KD │ │ ML │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ ↓ ↓ ↓ │
│ 表示细化 泛化保持 协同适应 │
└─────────────────────────────────────────────────────────────┘
4.2 对比学习(Contrastive Learning)
核心思想:通过伪标签引导的对比损失,逐步细化目标域的表示。
近似正确对比学习:
- 正样本定义:若样本 和 的 top- 预测存在交集,则视为正样本对
- 负样本定义:若两样本的 top-()预测无交集,则视为负样本对
- 模糊样本:不应用对比损失
指示函数定义:
对比损失函数:
其中 是特征间的余弦相似度, 是批次大小。
关键优势:
- 利用SSL模型的初始分类能力
- 保持适应过程的稳定性
- 避免低质量伪标签的误差传播
4.3 知识蒸馏(Knowledge Distillation)
核心思想:将SSL模型的知识迁移到目标模型,保持泛化能力并缓解连续域偏移下的过拟合。
知识蒸馏损失:
其中 是归一化特征向量, 是Frobenius范数。
设计原理:
- 通过对齐归一化特征向量,确保稳定的知识迁移
- 保留SSL模型嵌入空间的几何结构
- 减少特征表示的旋转,保持预测一致性
4.4 互学习(Mutual Learning)
核心思想:SSL模型提供泛化能力,目标模型提供领域特定知识,通过协同学习整合两种优势。
互学习损失:
其中:
- :目标模型的概率分布
- :目标模型生成的伪标签
- :交叉熵
- :互信息1
互信息的作用:捕获样本之间的关系信息,利用SSL模型的表示能力。
4.5 总体损失函数
AWS的总体损失函数为三个组件的加权和:
其中 和 是平衡各组件的超参数。
典型参数设置:
- (ImageNetC和CIFAR10C)或 (CIFAR100C)
5. 实验结果与分析
5.1 实验设置
数据集:
- ImageNet → ImageNetC
- CIFAR10 → CIFAR10C
- CIFAR100 → CIFAR100C
SSL模型:
- DINO(ViT-B/16)
- MoCo v3(ViT-B/16)
- iBOT(ViT-B/16)
评估协议:遵循持续测试时适应(CTTA)协议,在15个目标域上顺序适应。
5.2 主要结果
ImageNet → ImageNetC
| 方法 | 平均错误率 | 相对基线提升 |
|---|---|---|
| No Adapt(基线) | 69.2% | 0.0% |
| TENT | 67.3% | +1.9% |
| CoTTA | 71.1% | -1.9% |
| AWS(DINO) | 53.0% | +16.2% |
| AWS(iBOT) | 48.1% | +19.9% |
表5:ImageNetC上的SSL模型TTA性能1
CIFAR基准
| 模型 | CIFAR10C | CIFAR100C |
|---|---|---|
| DINO基线 | 44.3% | 64.1% |
| AWS(DINO) | 26.8% | 50.6% |
| MoCo基线 | 42.2% | 64.2% |
| AWS(MoCo) | 40.7% | 62.1% |
| iBOT基线 | 48.0% | 65.6% |
| AWS(iBOT) | 30.1% | 50.2% |
表6:CIFAR数据集上的平均错误率1
5.3 消融实验
| 对比学习 | 知识蒸馏 | 互学习 | ImageNetC | CIFAR10C | CIFAR100C |
|---|---|---|---|---|---|
| ✗ | ✗ | ✗ | 55.8% | 28.2% | 35.4% |
| ✓ | ✗ | ✗ | 43.4% | 16.0% | 27.9% |
| ✓ | ✓ | ✗ | 40.6% | 11.2% | 22.2% |
| ✓ | ✓ | ✓ | 39.4% | 10.8% | 20.4% |
表7:各组件的贡献分析1
关键发现:
- 对比学习单独使用即可带来显著提升(+12.4% on ImageNetC)
- 知识蒸馏有效缓解灾难性遗忘
- 互学习进一步提升性能,达到最佳组合效果
5.4 特征可视化分析
通过t-SNE可视化分析,发现:
- 传统方法:保守更新策略(如仅更新归一化层)保持初始表示,但受限于源模型初始性能
- AWS方法:展现更清晰的决策边界,对SSL模型和监督学习模型均有效
6. PyTorch实现
以下是AWS框架的完整PyTorch实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
class AWSLoss(nn.Module):
"""
Adapt With-out Source pretraining (AWS) Loss
整合对比学习、知识蒸馏和互学习的协作学习框架
"""
def __init__(
self,
num_classes: int,
k: int = 1,
n: int = 5,
lambda_kd: float = 0.01,
lambda_ml: float = 0.4,
temperature: float = 0.1
):
super().__init__()
self.k = k # 正样本阈值
self.n = n # 负样本阈值
self.lambda_kd = lambda_kd
self.lambda_ml = lambda_ml
self.temperature = temperature
def get_topk_predictions(
self,
logits: torch.Tensor,
k: int
) -> torch.Tensor:
"""获取top-k预测的类别索引"""
return torch.topk(logits, k=k, dim=-1).indices
def compute_indicator(
self,
topk_i: torch.Tensor,
topk_j: torch.Tensor,
topn_i: torch.Tensor,
topn_j: torch.Tensor
) -> torch.Tensor:
"""
计算样本对的关系指示函数
Returns:
indicator: 正样本=1, 负样本=-1, 模糊样本=0
"""
batch_size = topk_i.size(0)
# 正样本:top-k预测有交集
# [batch_size, k] -> [batch_size, batch_size, k]
topk_i_expand = topk_i.unsqueeze(1).expand(-1, batch_size, -1)
topk_j_expand = topk_j.unsqueeze(0).expand(batch_size, -1, -1)
positive = (topk_i_expand == topk_j_expand).any(dim=-1).float() # [B, B]
# 负样本:top-n预测无交集
topn_i_expand = topn_i.unsqueeze(1).expand(-1, batch_size, -1)
topn_j_expand = topn_j.unsqueeze(0).expand(batch_size, -1, -1)
negative = (topn_i_expand == topn_j_expand).any(dim=-1).float() # [B, B]
# 综合指示函数
indicator = positive - negative # 正样本1,负样本-1,模糊0
return indicator
def contrastive_loss(
self,
features: torch.Tensor,
logits: torch.Tensor,
indicator: torch.Tensor
) -> torch.Tensor:
"""
对比学习损失
Args:
features: [B, D] 归一化后的特征
logits: [B, C] 分类logits
"""
# 计算余弦相似度矩阵
similarity = torch.matmul(features, features.T) / self.temperature
# 计算正样本和负样本的指示掩码
mask_positive = (indicator == 1).float()
mask_negative = (indicator == -1).float()
mask_ignore = (indicator == 0).float()
# InfoNCE损失
exp_sim = torch.exp(similarity)
log_prob = similarity - torch.log(exp_sim.sum(dim=-1, keepdim=True))
# 仅对正样本对计算损失
loss = -(mask_positive * log_prob).sum() / (mask_positive.sum() + 1e-8)
return loss
def knowledge_distillation_loss(
self,
target_features: torch.Tensor,
ssl_features: torch.Tensor
) -> torch.Tensor:
"""
知识蒸馏损失:对齐SSL模型和目标模型的特征表示
Args:
target_features: [B, D] 目标模型特征(归一化)
ssl_features: [B, D] SSL模型特征(归一化)
"""
return F.mse_loss(target_features, ssl_features)
def mutual_learning_loss(
self,
ssl_probs: torch.Tensor,
target_probs: torch.Tensor,
target_logits: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
互学习损失
Args:
ssl_probs: [B, C] SSL模型概率分布
target_probs: [B, C] 目标模型概率分布
target_logits: [B, C] 目标模型logits
Returns:
ssl_loss: 更新SSL模型的损失
target_loss: 更新目标模型的损失
"""
# SSL模型损失:使用目标模型的伪标签
pseudo_labels = target_logits.argmax(dim=-1)
ssl_loss = F.cross_entropy(
torch.log(ssl_probs + 1e-8),
pseudo_labels
)
# 目标模型损失:互信息(简化版本)
# I(p_t, p_ssl) ≈ -KL(p_t || p_ssl)
target_loss = F.kl_div(
torch.log(target_probs + 1e-8),
ssl_probs,
reduction='batchmean'
)
return ssl_loss, target_loss
def forward(
self,
target_features: torch.Tensor,
ssl_features: torch.Tensor,
target_logits: torch.Tensor,
ssl_logits: torch.Tensor
) -> Tuple[torch.Tensor, dict]:
"""
AWS总损失计算
Args:
target_features: [B, D] 目标模型特征
ssl_features: [B, D] SSL模型特征
target_logits: [B, C] 目标模型logits
ssl_logits: [B, C] SSL模型logits
Returns:
total_loss: 总损失
loss_dict: 各组件损失
"""
batch_size = target_features.size(0)
# 归一化特征
target_features_norm = F.normalize(target_features, p=2, dim=-1)
ssl_features_norm = F.normalize(ssl_features, p=2, dim=-1)
# 获取top-k和top-n预测
target_topk = self.get_topk_predictions(target_logits, self.k)
target_topn = self.get_topk_predictions(target_logits, self.n)
# 计算指示函数
indicator = self.compute_indicator(
target_topk, target_topk,
target_topn, target_topn
)
# 对比学习损失
loss_cl = self.contrastive_loss(
target_features_norm,
target_logits,
indicator
)
# 知识蒸馏损失
loss_kd = self.knowledge_distillation_loss(
target_features_norm,
ssl_features_norm
)
# 互学习损失
target_probs = F.softmax(target_logits, dim=-1)
ssl_probs = F.softmax(ssl_logits, dim=-1)
loss_ssl, loss_target = self.mutual_learning_loss(
ssl_probs, target_probs, target_logits
)
loss_ml = loss_ssl + loss_target
# 总损失
total_loss = loss_cl + self.lambda_kd * loss_kd + self.lambda_ml * loss_ml
loss_dict = {
'loss_total': total_loss,
'loss_cl': loss_cl,
'loss_kd': loss_kd,
'loss_ml': loss_ml,
'loss_ssl': loss_ssl,
'loss_target': loss_target
}
return total_loss, loss_dict
class PrototypeClassifier(nn.Module):
"""
原型分类器:无需反向传播的分类器
通过计算类原型并使用余弦相似度进行分类
"""
def __init__(self, feat_dim: int, num_classes: int, logit_scale: float = 10.0):
super().__init__()
self.feat_dim = feat_dim
self.num_classes = num_classes
self.logit_scale = logit_scale
# 可学习的类原型(用于在线更新)
self.register_buffer('prototypes', torch.zeros(num_classes, feat_dim))
self.register_buffer('class_counts', torch.zeros(num_classes))
def update_prototypes(
self,
features: torch.Tensor,
labels: torch.Tensor,
momentum: float = 0.0
):
"""
更新类原型
Args:
features: [B, D] 特征向量
labels: [B] 类别标签
momentum: 动量更新系数(0表示完全替换)
"""
for feat, label in zip(features, labels):
label = label.item()
if momentum > 0:
self.prototypes[label] = (
momentum * self.prototypes[label] +
(1 - momentum) * feat.detach()
)
else:
self.prototypes[label] = feat.detach()
self.class_counts[label] += 1
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""
前向传播:计算分类logits
Args:
features: [B, D] 输入特征
Returns:
logits: [B, C] 分类logits
"""
# 归一化特征和原型
features_norm = F.normalize(features, p=2, dim=-1)
prototypes_norm = F.normalize(self.prototypes, p=2, dim=-1)
# 余弦相似度
similarity = torch.matmul(features_norm, prototypes_norm.T)
# 缩放
logits = similarity * self.logit_scale
return logits
class AWSTestTimeAdaptation:
"""
AWS测试时适应主类
整合SSL模型与目标模型的协作学习
"""
def __init__(
self,
ssl_model: nn.Module,
target_model: nn.Module,
num_classes: int,
device: torch.device,
k: int = 1,
n: int = 5,
lambda_kd: float = 0.01,
lambda_ml: float = 0.4,
lr: float = 1e-4,
momentum: float = 0.9
):
self.device = device
self.ssl_model = ssl_model.to(device)
self.target_model = target_model.to(device)
# 损失函数
self.criterion = AWSLoss(
num_classes=num_classes,
k=k,
n=n,
lambda_kd=lambda_kd,
lambda_ml=lambda_ml
)
# 原型分类器
self.prototype_classifier = PrototypeClassifier(
feat_dim=ssl_model.embed_dim,
num_classes=num_classes
).to(device)
# 优化器
self.optimizer = torch.optim.SGD(
target_model.parameters(),
lr=lr,
momentum=momentum
)
@torch.no_grad()
def compute_ssl_features(self, images: torch.Tensor) -> torch.Tensor:
"""计算SSL模型特征"""
self.ssl_model.eval()
return self.ssl_model(images)
def adapt_batch(self, images: torch.Tensor) -> dict:
"""
对单个批次进行适应
Args:
images: [B, C, H, W] 输入图像
Returns:
results: 包含loss和各组件损失的字典
"""
# 前向传播
ssl_features = self.compute_ssl_features(images)
target_features = self.target_model(images)
# 分类
ssl_logits = self.prototype_classifier(ssl_features)
target_logits = self.prototype_classifier(target_features)
# 计算损失
total_loss, loss_dict = self.criterion(
target_features=target_features,
ssl_features=ssl_features,
target_logits=target_logits,
ssl_logits=ssl_logits
)
# 反向传播
self.optimizer.zero_grad()
total_loss.backward()
self.optimizer.step()
return loss_dict
def evaluate(self, images: torch.Tensor) -> torch.Tensor:
"""
评估模式:仅使用目标模型
Args:
images: [B, C, H, W] 输入图像
Returns:
predictions: 预测类别
"""
self.target_model.eval()
with torch.no_grad():
target_features = self.target_model(images)
logits = self.prototype_classifier(target_features)
predictions = logits.argmax(dim=-1)
return predictions
# 使用示例
def demo_usage():
"""AWS框架使用示例"""
import torchvision.models as models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 假设SSL模型已预训练
ssl_backbone = models.vit_b_16(pretrained=False)
target_backbone = models.vit_b_16(pretrained=False)
# 创建TTA系统
aws_tta = AWSTestTimeAdaptation(
ssl_model=ssl_backbone,
target_model=target_backbone,
num_classes=1000,
device=device,
k=1,
n=5,
lambda_kd=0.01,
lambda_ml=0.4,
lr=1e-4 * 64 / 64, # 随batch size缩放
momentum=0.9
)
# 初始化原型分类器(使用少量样本)
# aws_tta.prototype_classifier.update_prototypes(features, labels)
# 测试适应过程
dummy_images = torch.randn(16, 3, 224, 224).to(device)
loss_dict = aws_tta.adapt_batch(dummy_images)
print("适应损失:")
for key, value in loss_dict.items():
print(f" {key}: {value.item():.4f}")
if __name__ == "__main__":
demo_usage()7. 与相关工作的联系
7.1 与对比学习的关系
AWS框架中的对比学习组件与对比学习理论密切相关。区别在于:
- 传统对比学习:在预训练阶段学习通用表示
- AWS对比学习:在测试时细化领域特定表示,利用伪标签关系
7.2 与知识蒸馏的关系
知识蒸馏组件继承自知识蒸馏基础的核心思想,但应用于:
- 传统知识蒸馏:从大模型到小模型的知识迁移
- AWS知识蒸馏:从SSL模型到目标模型的知识保持
7.3 与测试时计算扩展的关系
AWS方法与测试时计算扩展有着相似的哲学——在推理阶段投入额外计算以提升性能:
| 维度 | 测试时计算扩展 | AWS |
|---|---|---|
| 计算类型 | 推理时的多次前向传播 | 在线学习更新 |
| 目标 | 提升推理质量 | 适应分布偏移 |
| 应用场景 | LLM推理 | 视觉模型适应 |
8. 结论与展望
8.1 主要贡献
- 首次揭示了传统TTA方法在SSL模型上的失效问题
- 提出自监督测试时适应协议,显著降低计算成本
- 设计AWS协作学习框架,整合对比学习、知识蒸馏和互学习
- 验证了方法在多种SSL模型(DINO、MoCo、iBOT)上的有效性
8.2 局限性
- 仍需要一定量的目标域数据流
- 超参数对不同数据集可能需要调整
- 对极端分布偏移的鲁棒性有待进一步验证
8.3 未来方向
- 探索更高效的原型更新策略
- 将AWS扩展到更多SSL架构(如MAE、Jepa)
- 结合测试时计算扩展的混合方法