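C-TTA: Continuous Test-Time Adaptation
Overview
C-TTA (Continuous Test-Time Adaptation) is a continuous test-time adaptation framework for vision-language models (VLMs). Its core idea is to adapt to the entire target-domain distribution through continuously updated target prototypes. This addresses a limitation of existing methods, which either assume a static distribution or use only a small fraction of the test samples.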
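Core Contributions
- Backpropagation-free: only the target prototypes are updated, so no gradients are computed.
- No large cache: only a C×d prototype matrix is stored (C is the number of classes, d is the feature dimension).
- Uses every sample: the continuous update mechanism draws on all test samples.
- Multimodal extension: the approach extends to 3D VLMs (point cloud analysis).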
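Problem Definition
Limitations of existing TTA methods:
| Method type | Representative methods | Limitation / behavior |
|---|---|---|
| Instance-wise TTA | TPT, EATA | Assumes a static target-domain distribution |
| Episodic-wise TTA | TDA, BCA | Uses only a small high-confidence subset of samples (~14%) |
| Continuous TTA | C-TTA | Accumulates knowledge from all samples |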
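Method
1. Problem Formalization
Given a pre-trained VLM (such as CLIP) and an unlabeled target data distribution, the goal is to maximize prediction accuracy on the samples drawn from that distribution.
The core assumption of C-TTA is that target prototypes can accumulate domain knowledge from all test samples.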
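2. Core Mechanism
Target prototype definition
$$P_t = \left[P_t^{(1)}, \dots, P_t^{(C)}\right]^\top \in \mathbb{R}^{C \times d}$$
where:
- $C$: number of classes
- $d$: visual feature dimension
- $P_t^{(j)}$: target prototype of class $j$ at time step $t$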
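Adaptive fusion weight
$$\beta_j = 1 - \exp\left(-\frac{s_j}{h}\right)$$
where $s_j$ is the confidence for class $j$ (the CLIP probability) and $h$ is a temperature parameter.
Interpretation:
- When $s_j$ is high (the model is confident about class $j$), $\beta_j \to 1$ and the current prediction is fully trusted.
- When $s_j$ is low (the model is uncertain), $\beta_j \to 0$ and the update relies mostly on the historical prototype.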
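As a quick numeric check of this behavior, the sketch below evaluates the weight `1 - exp(-s / h)` that appears in the algorithm code for a few confidence values. The helper name `fusion_weight` and the sample confidences are illustrative assumptions; `h = 0.1` is the default used in the code.

```python
import math


def fusion_weight(s: float, h: float = 0.1) -> float:
    """Adaptive weight beta = 1 - exp(-s / h) for a class confidence s in [0, 1]."""
    return 1.0 - math.exp(-s / h)


# Sweep a few confidence values at the default temperature.
for s in (0.01, 0.1, 0.5, 0.9):
    print(f"s={s:.2f} -> beta={fusion_weight(s):.4f}")
```

With `h = 0.1` the weight saturates quickly: a confidence of 0.1 already gives a weight of about 0.63, so only classes with very low confidence leave their prototype close to unchanged.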
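Prototype update rule
$$P_t^{(j)} = (1 - \beta_j)\, P_{t-1}^{(j)} + \beta_j\, F_g$$
where $F_g$ is the visual feature of the current image and $\beta_j = 1 - \exp(-s_j / h)$ is the adaptive weight of class $j$.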
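The update is a convex combination of the stored prototype and the current feature, as in the line `P_t[j] = (1 - beta[j]) * P_t[j] + beta[j] * F_g` of the algorithm code. The sketch below applies it to a toy 2-dimensional case in plain Python; the vectors and the two weight values are made-up illustrations, not values from the method.

```python
def update_prototype(prototype, feature, beta):
    """Blend the stored prototype with the current feature using weight beta."""
    return [(1.0 - beta) * p + beta * f for p, f in zip(prototype, feature)]


old_prototype = [1.0, 0.0]
current_feature = [0.0, 1.0]

# A confident prediction (large beta) moves the prototype most of the way.
confident = update_prototype(old_prototype, current_feature, beta=0.9)
# An uncertain prediction (small beta) barely changes it.
uncertain = update_prototype(old_prototype, current_feature, beta=0.1)
print(confident, uncertain)
```

Because every class receives its own weight, a single test sample can update all C prototypes at once, each by a different amount.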
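Final prediction
$$p_{\text{final}} = p_{\text{CLIP}} + \lambda\, p_{\text{target}}$$
where:
$$p_{\text{target}} = \operatorname{softmax}\left(\frac{\cos\left(F_g, P_t\right)}{\tau}\right)$$
Here $\lambda$ balances the two terms, $\tau$ is the temperature for prototype similarity, and $\cos(F_g, P_t)$ is the vector of cosine similarities between the current feature and the $C$ prototypes.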
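A small pure-Python sketch of this fusion step follows, assuming the prototype branch is a softmax over temperature-scaled cosine similarities as in the code's `prototype_prediction`. The helper names and the toy vectors are hypothetical; `lambda_ = 1.0` and `tau = 0.07` are the defaults used in the code.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def cosine(a, b):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def fuse(p_clip, feature, prototypes, lambda_=1.0, tau=0.07):
    """Add the prototype-based probabilities, scaled by lambda_, to the CLIP ones."""
    sims = [cosine(feature, p) for p in prototypes]
    p_target = softmax([s / tau for s in sims])
    return [pc + lambda_ * pt for pc, pt in zip(p_clip, p_target)]


p_clip = [0.5, 0.4, 0.1]                              # CLIP slightly prefers class 0
prototypes = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]    # toy class prototypes
feature = [0.2, 0.98]                                 # feature close to prototype 1
scores = fuse(p_clip, feature, prototypes)
pred = max(range(len(scores)), key=scores.__getitem__)
print(scores, pred)
```

In this toy case the prototype branch overrides CLIP's marginal preference for class 0. The fused scores sum to `1 + lambda_` rather than 1, so they act as ranking scores; only the argmax is used here.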
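3. Algorithm
```python
# C-TTA core algorithm
import torch
import torch.nn.functional as F


def c_tta(model, target_images, C, d, lambda_=1.0, h=0.1, tau=0.07,
          init_prototypes=None):
    """
    Continuous Test-Time Adaptation for VLM

    Args:
        model: Pre-trained CLIP model. `encode_image(x)` -> (1, d) features and
            `predict(x)` -> (1, C) class probabilities are assumed wrappers.
        target_images: Stream of test images
        C: Number of classes
        d: Feature dimension
        lambda_: Balance weight for target prototype
        h: Temperature for confidence weighting
        tau: Temperature for prototype similarity
        init_prototypes: Optional (C, d) initial prototypes, e.g. source
            statistics; defaults to zeros
    Returns:
        Final prototypes (C, d) and the list of per-sample predictions
    """
    # Initialize target prototypes with source statistics when available
    P_t = torch.zeros(C, d) if init_prototypes is None else init_prototypes.clone()
    predictions = []
    with torch.no_grad():  # no backpropagation at test time
        for x in target_images:
            # Extract visual features
            F_g = model.encode_image(x)  # (1, d)
            P_t = P_t.to(F_g)  # match device and dtype of the features
            # Compute CLIP prediction
            p_clip = model.predict(x)  # (1, C)
            # Compute adaptive weights, one per class
            beta = (1 - torch.exp(-p_clip / h)).squeeze(0).unsqueeze(-1)  # (C, 1)
            # Update all class prototypes at once (F_g broadcasts over classes)
            P_t = (1 - beta) * P_t + beta * F_g  # (C, d)
            # Compute target prototype prediction (cosine similarity + softmax)
            sim = F.normalize(F_g, dim=-1) @ F.normalize(P_t, dim=-1).T  # (1, C)
            p_target = F.softmax(sim / tau, dim=-1)
            # Combined prediction
            p_final = p_clip + lambda_ * p_target
            predictions.append(p_final)
    return P_t, predictions
```
Technical Advantage Analysis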
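Comparison with Existing Methods
| Dimension | TPT | TDA | BCA | C-TTA |
|---|---|---|---|---|
| Adaptation unit | Single sample | Mini-batch | Mini-batch | Full distribution |
| Parameters updated | BN statistics | All | BN statistics | Prototypes |
| Requires backpropagation | ❌ | ✅ | ❌ | ❌ |
| Requires cache | ❌ | ✅ | ✅ | ❌ |
| Memory overhead | O(1) | O(BN×D) | O(BN×D) | O(C×d) |
| Sample utilization | 100% | ~14% | ~14% | 100% |
| Inference speedup | 1× | 0.3× | 0.4× | 5.7× |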
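Memory Complexity Analysis
Notation:
- $B$: batch size
- $N$: number of ViT patches
- $D$: hidden dimension
| Method | Memory overhead | Relative to C-TTA |
|---|---|---|
| TDA | $O(BND)$ | 150× |
| BCA | $O(BND)$ | 150× |
| C-TTA | $O(Cd)$ | 1× |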
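To make the asymptotic comparison concrete, the sketch below counts stored elements for an activation cache of size B·N·D and for a C×d prototype matrix. All numeric values are illustrative assumptions chosen for the example (a ViT-B/16-like patch count and width, 1000 classes); they are not the settings behind the 150× figure in the table.

```python
def cache_elements(batch_size: int, num_patches: int, hidden_dim: int) -> int:
    """Elements held by a cache that scales as O(B * N * D)."""
    return batch_size * num_patches * hidden_dim


def prototype_elements(num_classes: int, feature_dim: int) -> int:
    """Elements held by a C x d prototype matrix."""
    return num_classes * feature_dim


# Hypothetical settings, for illustration only.
B, N, D = 64, 197, 768
C, d = 1000, 512

cache = cache_elements(B, N, D)
protos = prototype_elements(C, d)
ratio = cache / protos
print(f"cache={cache:,} prototypes={protos:,} ratio={ratio:.1f}x")
```

The ratio depends strongly on the assumed batch size and class count: the prototype matrix grows with C, while the cache grows with B and N.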
Experimental Results
Cross-Dataset Generalization (ImageNet → other datasets)
| Method | IN-A | IN-R | IN-V2 | Average |
|---|---|---|---|---|
| CLIP (baseline) | 56.2% | 73.4% | 67.3% | 65.6% |
| TPT | 59.8% | 75.1% | 69.2% | 68.0% |
| TDA | 61.3% | 76.2% | 70.1% | 69.2% |
| BCA | 62.1% | 76.8% | 70.5% | 69.8% |
| C-TTA | 64.5% | 78.2% | 72.1% | 71.6% |
Domain Generalization (OOD setting)
| Method | ImageNet-C (average) | CIFAR-C | Average OOD |
|---|---|---|---|
| CLIP | 52.1% | 62.4% | 57.3% |
| TPT | 58.7% | 68.2% | 63.5% |
| TCA | 61.2% | 71.1% | 66.2% |
| C-TTA | 64.8% | 73.6% | 69.2% |
3D VLM Extension (point cloud analysis)
| Dataset | Task | CLIP baseline | C-TTA |
|---|---|---|---|
| ScanObjectNN | Classification | 48.2% | 63.1% |
| ModelNet40 | Classification | 56.7% | 71.3% |
| ShapeNetPart | Segmentation | 62.4% | 75.8% |
Efficiency Comparison
| Method | GPU memory | Inference time (relative to C-TTA) |
|---|---|---|
| TDA | 8.2GB | 5.7× |
| TPT | 4.1GB | 1.2× |
| C-TTA | 3.8GB | 1.0× |
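Ablation Studies
Prototype Update Mechanism
| Update strategy | Formula | Accuracy |
|---|---|---|
| Fixed weight | | 68.2% |
| Confidence-weighted (ours) | $\beta_j = 1 - \exp(-s_j / h)$ | 71.6% |
| Pure interpolation | | 67.8% |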
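Effect of the Temperature Parameter h
| h | IN-A accuracy | Stability |
|---|---|---|
| 0.01 | 62.1% | High |
| 0.1 | 64.5% | High |
| 0.5 | 63.8% | Medium |
| 1.0 | 61.2% | Low |
Recommended value: $h = 0.1$, the best-performing setting in the table and the default in the code.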
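One way to build intuition for this sweep is to look at how the weight `1 - exp(-s / h)` separates a confident class from an unconfident one at each temperature. The sketch below does that for the four values of h in the table; the two confidence levels (0.9 and 0.05) and the helper name `weight` are arbitrary illustrative choices, and the sketch shows only the shape of the weighting function, not the reported accuracies.

```python
import math


def weight(s: float, h: float) -> float:
    """Adaptive weight 1 - exp(-s / h) for confidence s and temperature h."""
    return 1.0 - math.exp(-s / h)


rows = []
for h in (0.01, 0.1, 0.5, 1.0):
    confident = weight(0.9, h)     # weight given to a confident class
    unconfident = weight(0.05, h)  # weight given to an unconfident class
    rows.append((h, confident, unconfident))
    print(f"h={h:<4} confident={confident:.3f} unconfident={unconfident:.3f}")
```

At very small h both weights are close to 1, so even low-confidence classes have their prototypes nearly overwritten by each sample. At large h even a confident prediction is only partially trusted.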
Effect of the Number of Classes C
C-TTA performance for different numbers of classes:
| C | IN-A | IN-R | IN-V2 |
|---|---|---|---|
| 100 | 63.2% | 77.5% | 71.0% |
| 500 | 64.1% | 77.9% | 71.5% |
| 1000 | 64.5% | 78.2% | 72.1% |
PyTorch Implementation
```python
import torch
import torch.nn.functional as F
from torch import nn


class CTTA(nn.Module):
    """
    Continuous Test-Time Adaptation for Vision-Language Models
    """

    def __init__(self, feature_dim: int, num_classes: int,
                 lambda_: float = 1.0, h: float = 0.1, tau: float = 0.07):
        super().__init__()
        self.num_classes = num_classes
        self.feature_dim = feature_dim
        self.lambda_ = lambda_
        self.h = h
        self.tau = tau
        # Target prototypes: a buffer updated in place at test time,
        # not a parameter trained by gradient descent
        self.register_buffer('P_t', torch.zeros(num_classes, feature_dim))
        self.initialized = False

    @torch.no_grad()
    def initialize_prototypes(self, model, loader):
        """Initialize prototypes with class-wise mean features of source data.

        Simplified version: assumes the loader yields (images, integer class
        labels) and that `model.encode_image` returns (B, feature_dim) features.
        """
        sums = torch.zeros_like(self.P_t)
        counts = torch.zeros(self.num_classes, device=self.P_t.device)
        for images, labels in loader:
            feat = model.encode_image(images).to(self.P_t)  # (B, d)
            labels = labels.to(self.P_t.device)
            sums.index_add_(0, labels, feat)
            counts.index_add_(0, labels, torch.ones_like(labels, dtype=counts.dtype))
        # Classes that never appear keep a zero prototype
        self.P_t.copy_(sums / counts.clamp(min=1).unsqueeze(-1))
        self.initialized = True

    @torch.no_grad()
    def update_prototype(self, features: torch.Tensor,
                         clip_probs: torch.Tensor):
        """
        Update target prototypes based on current sample
        Args:
            features: (batch_size, feature_dim) visual features
            clip_probs: (batch_size, num_classes) CLIP predictions
        """
        # Adaptive weight: 1 - exp(-s/h)
        beta = 1 - torch.exp(-clip_probs / self.h)  # (B, C)
        # Expand dimensions for broadcasting
        beta = beta.unsqueeze(-1)  # (B, C, 1)
        features = features.unsqueeze(1)  # (B, 1, d)
        # Batch-averaged form of the per-sample rule (identical when B = 1)
        update = (beta * features).mean(dim=0)  # (C, d)
        # beta.mean(dim=0) is already (C, 1); an extra unsqueeze would
        # break broadcasting against the (C, d) prototypes
        self.P_t.mul_(1 - beta.mean(dim=0))
        self.P_t.add_(update)

    def prototype_prediction(self, features: torch.Tensor) -> torch.Tensor:
        """
        Compute prediction based on target prototypes
        Args:
            features: (batch_size, feature_dim)
        Returns:
            probs: (batch_size, num_classes)
        """
        # Normalize features and prototypes
        features = F.normalize(features, p=2, dim=-1)
        prototypes = F.normalize(self.P_t, p=2, dim=-1)
        # Cosine similarity
        similarity = torch.matmul(features, prototypes.T)  # (B, C)
        # Softmax with temperature
        probs = F.softmax(similarity / self.tau, dim=-1)
        return probs

    def forward(self, features: torch.Tensor,
                clip_probs: torch.Tensor) -> torch.Tensor:
        """
        C-TTA forward pass
        Args:
            features: Visual features from CLIP encoder
            clip_probs: Classification probabilities from CLIP
        Returns:
            Combined predictions
        """
        # Update prototypes
        self.update_prototype(features, clip_probs)
        # Compute target prototype prediction
        p_target = self.prototype_prediction(features)
        # Combined prediction
        p_final = clip_probs + self.lambda_ * p_target
        return p_final
```
Relationship to Other VLM TTA Methods
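Method Taxonomy
VLM test-time adaptation
├── Prompt engineering (Prompt Tuning)
│ ├── CoOp: learnable prompts
│ ├── MaPLe: multimodal prompts
│ └── ProDA: dynamic prompts
├── Test-time adaptation (Test-Time Adaptation)
│ ├── BN statistics update: TENT
│ ├── Entropy minimization: TPT, EATA
│ ├── Prototype update: BCA, C-TTA (this work)
│ └── Continuous modeling: C-TTA
└── External knowledge: CLIP guidance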
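What Sets C-TTA Apart
- Prototypes vs. parameters: it updates prototypes rather than model parameters.
- Accumulate vs. reset: it keeps accumulating knowledge instead of resetting after every batch.
- Label-free: it needs no class labels from the target domain.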
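Limitations
- Fixed class set: the number of target-domain classes must be known in advance.
- Prototype initialization: it depends on source-domain statistics or a predefined initialization.
- Long-tailed distributions: gains may be limited under extreme class imbalance.
- Concept drift: it adapts only slowly to rapid concept drift.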
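Summary
C-TTA's core contribution is a continuous test-time adaptation framework for VLMs. Its target prototype mechanism achieves:
| Property | Improvement |
|---|---|
| Sample utilization | 14% → 100% |
| Memory overhead | O(BND) → O(Cd) |
| Inference speed | 5.7× faster than TDA |
| Accuracy (average OOD) | TCA 66.2% → 69.2% |
Key innovations:
- Adaptive fusion weights
- A continuous prototype update mechanism
- Dual prediction that combines CLIP knowledge with target-domain knowledge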