C-TTA: Continuous Test-Time Adaptation

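Overview

C-TTA (Continuous Test-Time Adaptation) is a continuous test-time adaptation framework for vision-language models (VLMs). Its core idea is to adapt to the entire target-domain distribution through continuously updated target prototypes. This addresses a limitation of existing methods, which either assume a static distribution or use only a small fraction of the test samples.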

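Core contributions

  1. No backpropagation: only the target prototypes are updated, so no gradients are computed
  2. No large cache: only a C×d prototype matrix is stored (C is the number of classes, d the feature dimension)
  3. Full sample utilization: the continuous update mechanism draws on every test sample
  4. Multimodal extension: the approach extends to 3D VLMs (point cloud analysis)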
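
To make the storage claim in item 2 concrete, the sketch below computes the size of a C×d prototype matrix. The values C = 1000 and d = 512, the float32 storage, and the helper name `prototype_bytes` are assumptions chosen for illustration, not figures taken from the method.

```python
# Size of the prototype matrix, assuming float32 storage (4 bytes per value).
# The class count and feature dimension used below are illustrative assumptions.
def prototype_bytes(num_classes: int, feature_dim: int, bytes_per_value: int = 4) -> int:
    """Return the storage needed for a (num_classes x feature_dim) matrix."""
    return num_classes * feature_dim * bytes_per_value


size = prototype_bytes(num_classes=1000, feature_dim=512)
print(f"{size} bytes = {size / 2**20:.2f} MiB")  # 1000 * 512 * 4 = 2,048,000 bytes
```

Under these assumptions the whole adaptation state is about 2 MiB, and it does not grow with the number of test samples seen.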

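Problem definition

Limitations of existing TTA methods:

| Method type | Representative methods | Limitation |
|---|---|---|
| Instance-wise TTA | TPT, EATA | Assumes a static target-domain distribution |
| Episodic-wise TTA | TDA, BCA | Uses only a small high-confidence subset of samples (~14%) |
| Continuous TTA | C-TTA | Uses all samples to accumulate knowledge |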

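Method details

1. Problem formalization

Given a pre-trained VLM (such as CLIP) and an unlabeled target data distribution, the goal is to maximize:

The core assumption of C-TTA is that the target prototypes can accumulate domain knowledge from all test samples.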

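2. Core mechanism

Target prototype definition

$$P_t = [P_t^1, P_t^2, \dots, P_t^C] \in \mathbb{R}^{C \times d}$$

where:

  • $C$: number of classes
  • $d$: visual feature dimension
  • $P_t^j$: target prototype of class $j$ at time step $t$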

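Adaptive fusion weight

$$\beta_j = 1 - \exp(-s_j / h)$$

where $s_j$ is the confidence for class $j$ and $h$ is a temperature parameter.

Physical meaning

  • When $s_j$ is high (the model is confident about class $j$), $\beta_j \to 1$, so the current prediction is trusted almost entirely
  • When $s_j$ is low (the model is uncertain), $\beta_j \to 0$, so the historical prototype dominates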
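
The two regimes can be checked numerically. The sketch below assumes the weight has the form used in this document's code, 1 - exp(-s/h), with h = 0.1 as in the code defaults. The function name `fusion_weight` and the sample confidences are illustrative.

```python
import math


def fusion_weight(confidence: float, h: float = 0.1) -> float:
    """Fusion weight 1 - exp(-s / h) for a class confidence s in [0, 1]."""
    return 1.0 - math.exp(-confidence / h)


# Illustrative confidences, from very uncertain to very confident.
for s in (0.01, 0.1, 0.5, 0.9):
    print(f"s={s:.2f} -> beta={fusion_weight(s):.4f}")
```

With h = 0.1 the weight is already above 0.99 at s = 0.5, so only classes with small confidence keep most of their historical prototype.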

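Prototype update rule

$$P_t^j = (1 - \beta_j)\, P_{t-1}^j + \beta_j F_g$$

where $F_g$ is the visual feature of the current image.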
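
A single update step can be traced by hand. The sketch below assumes the interpolation form used in this document's code, (1 - beta) * P + beta * F, applied to one class prototype. The two-dimensional vectors and the helper name `update_prototype` are illustrative.

```python
def update_prototype(prototype, feature, beta):
    """Blend the old prototype with the new feature: (1 - beta) * P + beta * F."""
    return [(1.0 - beta) * p + beta * f for p, f in zip(prototype, feature)]


old = [1.0, 0.0]   # hypothetical prototype from earlier samples
feat = [0.0, 1.0]  # hypothetical feature of the current image

print(update_prototype(old, feat, beta=0.25))  # [0.75, 0.25]
print(update_prototype(old, feat, beta=1.0))   # [0.0, 1.0]: feature replaces prototype
print(update_prototype(old, feat, beta=0.0))   # [1.0, 0.0]: prototype unchanged
```

The weight therefore acts as a per-class step size: it moves the prototype a fraction beta of the way toward the current feature.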

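Final prediction

$$p_{\text{final}} = p_{\text{CLIP}} + \lambda\, p_{\text{target}}$$

where:

$$p_{\text{target}} = \operatorname{softmax}\!\left(\cos(F_g, P_t) / \tau\right)$$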
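
The combination step can be sketched in plain Python. The sketch assumes the form used in this document's code: a softmax over cosine similarities between the feature and each prototype, divided by tau, added to the CLIP probabilities with weight lambda. The helper names and the toy vectors are illustrative.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def cosine(u, v):
    """Cosine similarity; assumes both vectors are non-zero."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def final_prediction(feature, prototypes, p_clip, lambda_=1.0, tau=0.07):
    """Return p_clip + lambda_ * softmax(cos(feature, prototypes) / tau)."""
    sims = [cosine(feature, p) for p in prototypes]
    p_target = softmax([s / tau for s in sims])
    return [pc + lambda_ * pt for pc, pt in zip(p_clip, p_target)]


# Toy example: two classes, feature close to the first prototype.
scores = final_prediction(
    feature=[1.0, 0.0],
    prototypes=[[0.9, 0.1], [0.1, 0.9]],
    p_clip=[0.6, 0.4],
)
predicted_class = scores.index(max(scores))
print(scores, predicted_class)
```

Because two probability vectors are added, the combined scores sum to 1 + lambda rather than 1. They are suitable for taking an argmax, and would need renormalizing before being read as probabilities.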

3. Algorithm

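# C-TTA core algorithm
import torch
import torch.nn.functional as F


@torch.no_grad()
def c_tta(model, target_images, text_features, P_init=None,
          lambda_=1.0, h=0.1, tau=0.07):
    """
    Continuous Test-Time Adaptation for a VLM

    Args:
        model: Pre-trained CLIP-style model exposing encode_image
        target_images: Stream of test images, each a batch of size 1
        text_features: (C, d) L2-normalized class text embeddings
        P_init: Optional (C, d) initial prototypes, e.g. source statistics;
            falls back to text_features (an assumption of this sketch)
        lambda_: Balance weight for target prototype
        h: Temperature for confidence weighting
        tau: Temperature for prototype similarity

    Returns:
        P_t: (C, d) final target prototypes
        preds: List of (C,) combined scores, one per test image
    """
    # Initialize target prototypes
    P_t = (text_features if P_init is None else P_init).clone().float()
    preds = []

    for x in target_images:
        # Extract and normalize the visual feature
        F_g = F.normalize(model.encode_image(x).float(), dim=-1).squeeze(0)  # (d,)

        # CLIP zero-shot prediction (100 is CLIP's usual logit scale)
        p_clip = F.softmax(100.0 * F_g @ text_features.float().T, dim=-1)  # (C,)

        # Adaptive per-class weights: beta_j = 1 - exp(-s_j / h)
        beta = (1 - torch.exp(-p_clip / h)).unsqueeze(-1)  # (C, 1)

        # Update all class prototypes at once via broadcasting
        P_t = (1 - beta) * P_t + beta * F_g  # (C, d)

        # Target prototype prediction from cosine similarity
        p_target = F.softmax(F_g @ F.normalize(P_t, dim=-1).T / tau, dim=-1)  # (C,)

        # Combined prediction for this sample
        preds.append(p_clip + lambda_ * p_target)

    return P_t, preds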

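Technical advantages

Comparison with existing methods

| Aspect | TPT | TDA | BCA | C-TTA |
|---|---|---|---|---|
| Adaptation unit | Single sample | Mini-batch | Mini-batch | Full distribution |
| Parameters updated | BN statistics | All | BN statistics | Prototypes |
| Backpropagation required | | | | |
| Cache required | | | | |
| Memory overhead | O(1) | O(BN×D) | O(BN×D) | O(C×d) |
| Sample utilization | 100% | ~14% | ~14% | 100% |

Inference speedup (three values given): 0.3×, 0.4×, 5.7×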

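Memory complexity analysis

Assume:

  • B (batch size)
  • N (number of ViT patches)
  • D (hidden dimension)

| Method | Memory overhead | Compression ratio relative to C-TTA |
|---|---|---|
| TDA | | 150× |
| BCA | | 150× |
| C-TTA | | |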
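
The ratio between an O(BN×D) cache and an O(C×d) prototype matrix can be computed directly. The concrete values below (B = 512, N = 196, D = 768, C = 1000, d = 512) are hypothetical and chosen only for illustration, since the values used in this analysis are not shown. The helper names are likewise illustrative.

```python
def cache_elements(batch_size: int, num_patches: int, hidden_dim: int) -> int:
    """Number of stored values for a cache that scales as B * N * D."""
    return batch_size * num_patches * hidden_dim


def prototype_elements(num_classes: int, feature_dim: int) -> int:
    """Number of stored values for a C x d prototype matrix."""
    return num_classes * feature_dim


# Hypothetical sizes, not taken from the source.
B, N, D = 512, 196, 768
C, d = 1000, 512

ratio = cache_elements(B, N, D) / prototype_elements(C, d)
print(f"{ratio:.1f}x")  # 77,070,336 / 512,000, about 150.5x
```

With these particular assumptions the ratio lands near the 150× figure in the table, but other choices of B, N, and D would give a different ratio.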

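Experimental results

Cross-dataset generalization (ImageNet → other datasets)

| Method | IN-A | IN-R | IN-V2 | Average |
|---|---|---|---|---|
| CLIP (baseline) | 56.2% | 73.4% | 67.3% | 65.6% |
| TPT | 59.8% | 75.1% | 69.2% | 68.0% |
| TDA | 61.3% | 76.2% | 70.1% | 69.2% |
| BCA | 62.1% | 76.8% | 70.5% | 69.8% |
| C-TTA | 64.5% | 78.2% | 72.1% | 71.6% |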
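
The Average column can be recomputed from the three per-dataset accuracies. The sketch below does this for the rows of the table above; the dictionary layout and variable names are illustrative.

```python
# Per-dataset accuracies (IN-A, IN-R, IN-V2) and the reported average, in percent,
# copied from the cross-dataset table.
results = {
    "CLIP": ((56.2, 73.4, 67.3), 65.6),
    "TPT": ((59.8, 75.1, 69.2), 68.0),
    "TDA": ((61.3, 76.2, 70.1), 69.2),
    "BCA": ((62.1, 76.8, 70.5), 69.8),
    "C-TTA": ((64.5, 78.2, 72.1), 71.6),
}

recomputed = {name: sum(accs) / len(accs) for name, (accs, _) in results.items()}

for name, (_, reported) in results.items():
    print(f"{name}: recomputed {recomputed[name]:.1f}%, reported {reported:.1f}%")

# Gain of C-TTA over the strongest baseline in this table (BCA), in points.
gain = recomputed["C-TTA"] - recomputed["BCA"]
```

Each recomputed mean matches the reported average to one decimal place, and the gain over BCA is about 1.8 points.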

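Domain generalization (OOD setting)

| Method | ImageNet-C (average) | CIFAR-C | Average OOD |
|---|---|---|---|
| CLIP | 52.1% | 62.4% | 57.3% |
| TPT | 58.7% | 68.2% | 63.5% |
| TCA | 61.2% | 71.1% | 66.2% |
| C-TTA | 64.8% | 73.6% | 69.2% |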

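3D VLM extension (point cloud analysis)

| Dataset | Task | CLIP baseline | C-TTA |
|---|---|---|---|
| ScanObjectNN | Classification | 48.2% | 63.1% |
| ModelNet40 | Classification | 56.7% | 71.3% |
| ShapeNetPart | Segmentation | 62.4% | 75.8% |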

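Efficiency comparison

| Method | GPU memory | Inference time (relative) |
|---|---|---|
| TDA | 8.2 GB | 5.7× |
| TPT | 4.1 GB | 1.2× |
| C-TTA | 3.8 GB | 1.0× |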

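Ablation studies

Prototype update mechanism

| Update strategy | Formula | Result |
|---|---|---|
| Fixed weight | | 68.2% |
| Confidence weighting (this work) | | 71.6% |
| Pure interpolation | | 67.8% |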

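Effect of the temperature parameter h

| h | IN-A accuracy | Stability |
|---|---|---|
| 0.01 | 62.1% | |
| 0.1 | 64.5% | |
| 0.5 | 63.8% | |
| 1.0 | 61.2% | |

Recommended value: the table above favors h = 0.1, which is also the default in the code.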
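
One way to see why h matters is to compare the fusion weight of a low-confidence class and a high-confidence class at each temperature. The sketch below assumes the weight form 1 - exp(-s/h) from this document's code. The two confidence values are hypothetical, and the sketch makes no claim about accuracy.

```python
import math


def fusion_weight(confidence: float, h: float) -> float:
    """Fusion weight 1 - exp(-s / h)."""
    return 1.0 - math.exp(-confidence / h)


LOW, HIGH = 0.05, 0.9  # hypothetical class confidences

for h in (0.01, 0.1, 0.5, 1.0):
    low, high = fusion_weight(LOW, h), fusion_weight(HIGH, h)
    print(f"h={h:<4} beta(low)={low:.3f}  beta(high)={high:.3f}")
```

At h = 0.01 both weights are close to 1, so even low-confidence classes overwrite their prototypes with the current feature. At h = 1.0 even the confident class moves only partway, so adaptation is slow.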

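Effect of the number of classes C

C-TTA performance for different numbers of classes:

| C | IN-A | IN-R | IN-V2 |
|---|---|---|---|
| 100 | 63.2% | 77.5% | 71.0% |
| 500 | 64.1% | 77.9% | 71.5% |
| 1000 | 64.5% | 78.2% | 72.1% |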

PyTorch implementation

import torch
import torch.nn.functional as F
from torch import nn
 
class CTTA(nn.Module):
    """
    Continuous Test-Time Adaptation for Vision-Language Models
    """
    def __init__(self, feature_dim: int, num_classes: int, 
                 lambda_: float = 1.0, h: float = 0.1, tau: float = 0.07):
        super().__init__()
        self.num_classes = num_classes
        self.feature_dim = feature_dim
        self.lambda_ = lambda_
        self.h = h
        self.tau = tau
        
        # Target prototypes (a non-trainable buffer, updated in place at test time)
        self.register_buffer('P_t', torch.zeros(num_classes, feature_dim))
        self.initialized = False
    
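    @torch.no_grad()
    def initialize_prototypes(self, model, loader):
        """Initialize prototypes with class-wise mean features of labeled source data"""
        device = self.P_t.device
        sums = torch.zeros_like(self.P_t)
        counts = torch.zeros(self.num_classes, 1, device=device)
        for images, labels in loader:
            feat = model.encode_image(images.to(device)).float()
            labels = labels.to(device)
            # Accumulate per-class feature sums and sample counts
            sums.index_add_(0, labels, feat)
            counts.index_add_(0, labels, torch.ones_like(labels, dtype=torch.float).unsqueeze(-1))
        # Class-wise mean; classes without samples keep a zero prototype
        self.P_t.copy_(sums / counts.clamp(min=1))
        self.initialized = True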
    
    def update_prototype(self, features: torch.Tensor, 
                         clip_probs: torch.Tensor):
        """
        Update target prototypes based on current sample
        
        Args:
            features: (batch_size, feature_dim) visual features
            clip_probs: (batch_size, num_classes) CLIP predictions
        """
        # Adaptive weight: 1 - exp(-s/h)
        beta = 1 - torch.exp(-clip_probs / self.h)  # (B, C)
        
        # Expand dimensions for broadcasting
        beta = beta.unsqueeze(-1)  # (B, C, 1)
        features = features.unsqueeze(1)  # (B, 1, d)
        
        # Update prototypes in place, averaging the weights over the batch
        with torch.no_grad():
            update = (beta * features).mean(dim=0)  # (C, d)
            # beta.mean(dim=0) is (C, 1) and broadcasts over the feature dimension
            self.P_t.mul_(1 - beta.mean(dim=0))
            self.P_t.add_(update)
    
    def prototype_prediction(self, features: torch.Tensor) -> torch.Tensor:
        """
        Compute prediction based on target prototypes
        
        Args:
            features: (batch_size, feature_dim)
            
        Returns:
            probs: (batch_size, num_classes)
        """
        # Normalize features and prototypes
        features = F.normalize(features, p=2, dim=-1)
        prototypes = F.normalize(self.P_t, p=2, dim=-1)
        
        # Cosine similarity
        similarity = torch.matmul(features, prototypes.T)  # (B, C)
        
        # Softmax with temperature
        probs = F.softmax(similarity / self.tau, dim=-1)
        return probs
    
    def forward(self, features: torch.Tensor, 
                clip_probs: torch.Tensor) -> torch.Tensor:
        """
        C-TTA forward pass
        
        Args:
            features: Visual features from CLIP encoder
            clip_probs: Classification probabilities from CLIP
            
        Returns:
            Combined predictions
        """
        # Update prototypes
        self.update_prototype(features, clip_probs)
        
        # Compute target prototype prediction
        p_target = self.prototype_prediction(features)
        
        # Combined prediction
        p_final = clip_probs + self.lambda_ * p_target
        
        return p_final
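
The batched update in `update_prototype` averages both the weights and the weighted features over the batch. The dependency-free sketch below reproduces that arithmetic for a single class prototype, as it appears intended, and checks that a batch of size 1 reduces to the single-sample rule. The function name `batched_update` and the toy values are illustrative.

```python
def batched_update(prototype, features, betas):
    """Batch-averaged update for one class prototype.

    prototype: list of d floats
    features:  list of B feature vectors, each a list of d floats
    betas:     list of B fusion weights for this class
    Computes P * (1 - mean(beta)) + mean(beta_b * f_b).
    """
    batch = len(features)
    mean_beta = sum(betas) / batch
    updated = []
    for k, p in enumerate(prototype):
        mean_update = sum(b * f[k] for b, f in zip(betas, features)) / batch
        updated.append((1.0 - mean_beta) * p + mean_update)
    return updated


# Batch of one: identical to (1 - beta) * P + beta * F.
single = batched_update([1.0, 0.0], [[0.0, 1.0]], [0.25])

# Batch of two: one confident sample (beta = 0.5) and one ignored sample (beta = 0).
pair = batched_update([1.0, 0.0], [[0.0, 1.0], [0.0, 1.0]], [0.5, 0.0])
print(single, pair)
```

Averaging over the batch means a single confident sample in a large batch moves the prototype less than it would if processed alone, so the effective step size depends on the batch size.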

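Relationship to other VLM TTA methods

Method taxonomy

VLM test-time adaptation
├── Prompt engineering (Prompt Tuning)
│   ├── CoOp: learnable prompts
│   ├── MaPLe: multimodal prompts
│   └── ProDA: dynamic prompts
├── Test-time adaptation (Test-Time Adaptation)
│   ├── BN statistics update: TENT
│   ├── Entropy minimization: TPT, EATA
│   ├── Prototype update: BCA, C-TTA (this work)
│   └── Continuous modeling: C-TTA
└── External knowledge: CLIP guidance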

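What sets C-TTA apart

  1. Prototypes vs. parameters: it updates prototypes rather than model parameters
  2. Accumulate vs. reset: it accumulates knowledge continuously rather than resetting after each batch
  3. No labels required: it needs no target-domain class labels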
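
The difference between accumulating and resetting can be illustrated with a toy one-dimensional stream. The sketch below uses a constant weight beta, a prototype that starts at 0, and a reset every 5 samples. These are simplifying assumptions for illustration and do not model any specific baseline.

```python
def run_stream(features, beta, reset_every=None):
    """Apply P <- (1 - beta) * P + beta * f over a stream of scalar features.

    If reset_every is set, the prototype is reset to 0.0 at the start of
    every block of that many samples (a stand-in for per-batch resetting).
    """
    prototype = 0.0
    for step, f in enumerate(features):
        if reset_every is not None and step % reset_every == 0:
            prototype = 0.0
        prototype = (1.0 - beta) * prototype + beta * f
    return prototype


stream = [1.0] * 20  # hypothetical stream whose true class mean is 1.0
accumulated = run_stream(stream, beta=0.1)
reset = run_stream(stream, beta=0.1, reset_every=5)
print(f"accumulated={accumulated:.3f}, reset={reset:.3f}")
```

The accumulated prototype ends at 1 - 0.9^20, while the resetting variant can never exceed 1 - 0.9^5, because each block starts again from the initial value.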

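Limitations

  1. Number of classes: the number of target-domain classes must be known in advance
  2. Prototype initialization: the method relies on source-domain statistics or a predefined initialization
  3. Long-tailed distributions: it may be less effective on extremely imbalanced distributions
  4. Concept drift: its ability to adapt to rapid concept drift is limited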

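Summary

The core contribution of C-TTA is a continuous test-time adaptation framework for VLMs. Its target-prototype mechanism achieves:

| Property | Improvement |
|---|---|
| Sample utilization | ~14% → 100% |
| Memory overhead | O(BND) → O(Cd) |
| Inference speed | 5.7× faster than TDA |
| Accuracy | OOD average: 66.2% (TCA) → 69.2%; cross-dataset average: 69.8% (BCA) → 71.6% |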

Key innovations

  1. Adaptive fusion weight
  2. Continuous prototype update mechanism
  3. Dual prediction combining CLIP knowledge and target knowledge
