MORPHEUS:基于神经崩溃几何的元测试时适应

概述

MORPHEUS(Meta Test-Time Adaptation via Neural Collapse Geometry)是首个元测试时适应框架,解决了一个关键问题:如何不执行实际适应就能预测哪种TTA方法最有效?

核心问题

现有TTA方法存在两个核心问题:

  1. 方法选择困境:不同TTA方法在不同数据分布上表现差异巨大,没有统一最优方法
  2. 计算成本:逐一尝试所有TTA方法的成本过高

核心洞察

嵌入空间的神经崩溃(Neural Collapse)几何特征可以预测TTA后的准确率

MORPHEUS发现:

  • 类别预测的softmax熵
  • 嵌入空间的神经崩溃几何特征

这两类特征可以预测最佳TTA方法和适应后的准确率。


神经崩溃理论

Neural Collapse现象

神经崩溃是深度网络训练后期观察到的几何现象1

阶段现象
NC1变量坍缩:同类特征的方差趋近于零
NC2收敛到Simplex ETF:类均值形成等角紧框架
NC3自对偶收敛:类均值与最后一层权重相互收敛
NC4简化决策:决策等价于最近类中心

Simplex ETF结构

在训练后期,网络学习到的特征满足:

其中:

  • :类别i的类均值
  • :最后一层权重向量
  • :类别数

这意味着类均值均匀分布在超球面上,形成等角紧框架(ETF)


方法详解

问题形式化

RQ1: 给定源模型 和无标签目标数据 ,从 个TTA方法中选择会产生最高准确率的方法

RQ2: 预测使用TTA方法 适应后的准确率

MORPHEUS框架

输入: 源模型 f_s + 目标数据 Ẋ
       ↓
1. 提取源模型在Ẋ上的特征:
   - Softmax熵: H
   - 嵌入向量: h_Ẋ
       ↓
2. 计算Neural Collapse几何特征 (5个)
       ↓
3. 使用回归模型预测各TTA方法的相对排名和绝对准确率
       ↓
4. 选择预测准确率最高的TTA方法
       ↓
输出: 最佳TTA方法 A* 和预测准确率

Neural Collapse几何特征

MORPHEUS提取5个NC几何特征:

特征1:类内紧凑度 (Within-Class Spread)

其中 是类别c的所有嵌入向量。

特征2:类间分离度 (Between-Class Spread)

特征3:类中心对齐度 (ETF Alignment)

理想情况下

特征4:决策边界距离 (Decision Boundary Distance)

特征5:特征空间曲率 (Feature Space Curvature)

回归模型

使用随机森林回归器进行预测:

  • 输入:
  • 输出1:各TTA方法的相对排名
  • 输出2:适应后的绝对准确率
# MORPHEUS 核心伪代码
def morpheus(source_model, target_data, tta_methods):
    # 1. 提取特征
    features = extract_features(source_model, target_data)  # (n, d)
    probs = source_model.predict(target_data)  # (n, C)
    entropy = compute_entropy(probs)  # (n,)
    
    # 2. 计算NC几何特征
    nc_features = compute_neural_collapse_features(features, probs)
    
    # 3. 预测各TTA方法的准确率
    predicted_accuracies = {}
    for method in tta_methods:
        pred = regressor.predict([entropy.mean(), *nc_features], method)
        predicted_accuracies[method] = pred
    
    # 4. 选择最佳方法
    best_method = max(predicted_accuracies, key=lambda k: predicted_accuracies[k])
    
    return best_method, predicted_accuracies

实验结果

ImageNet-C基准(ResNet50-BN, severity=5)

方法平均准确率标准差最高准确率
Source Only62.9%--
TENT63.1%0.968.2%
EATA63.5%1.269.1%
ROID (Prior SOTA)64.2%0.569.8%
MORPHEUS (NC)65.2%0.970.5%

关键发现

1. NC特征优于纯熵特征

特征组合预测RMSE说明
纯熵0.086基线
NC几何0.054-37%
NC + 熵0.052最佳

2. 防止灾难性失败

MORPHEUS可以识别会导致灾难性失败的TTA方法:

腐败类型源模型TENTEATAMORPHEUS选择
Gaussian Blur62.9%63.1%64.2%EATA (正确)
Defocus Blur61.2%58.3% ❌64.5%EATA (正确)
Motion Blur60.8%61.1%59.8% ❌TENT (正确)

❌ 表示该方法在该类型上表现恶化。

3. 稳定性提升

指标Prior SOTAMORPHEUS改进
平均准确率62.9%65.2%+2.3%
标准差0.90.7-22%
灾难失败次数123-75%

理论分析

为什么NC几何特征有效?

1. 域偏移与NC的关系

域偏移会影响NC的几何结构:

域偏移程度NC特征变化TTA效果预测
轻微 略增, 略降熵最小化有效
中等 下降, 变化稳定性方法更好
严重NC结构被破坏需选择鲁棒方法

2. TTA方法与NC特征匹配

TTA方法最适NC特征原理
TENT (熵最小化), 低 网络仍保持ETF结构
EATA (熵+多样性)中等 需要正则化防止过拟合
ROID (稳定性), 高 网络结构已破坏

实现细节

PyTorch实现

import torch
import torch.nn.functional as F
import numpy as np
from sklearn.ensemble import RandomForestRegressor
 
class MORPHEUS:
    """
    Meta Test-Time Adaptation via Neural Collapse Geometry
    """
    def __init__(self, num_classes, regressor=None):
        self.num_classes = num_classes
        self.regressor = regressor or RandomForestRegressor(
            n_estimators=100, max_depth=5
        )
        self.tta_methods = ['tent', 'eata', 'roid', 'tpt']
        self.fitted = False
    
    def extract_features(self, model, dataloader, device='cuda'):
        """Extract features and predictions from source model"""
        model.eval()
        all_features = []
        all_probs = []
        
        with torch.no_grad():
            for images, _ in dataloader:
                images = images.to(device)
                features = model.forward_features(images)
                probs = F.softmax(model(images), dim=-1)
                all_features.append(features)
                all_probs.append(probs)
        
        features = torch.cat(all_features, dim=0)
        probs = torch.cat(all_probs, dim=0)
        entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
        
        return features, probs, entropy
    
    def compute_neural_collapse_features(self, features, probs):
        """
        Compute 5 Neural Collapse geometric features
        """
        C = self.num_classes
        n, d = features.shape
        
        # Get predicted classes
        pred_classes = probs.argmax(dim=-1)
        
        # Feature 1: Within-class spread
        within_spread = 0
        for c in range(C):
            class_mask = (pred_classes == c)
            if class_mask.sum() > 1:
                class_features = features[class_mask]
                within_spread += class_features.var(dim=0).mean()
        f1 = within_spread / C
        
        # Feature 2: Between-class spread
        class_means = []
        for c in range(C):
            class_mask = (pred_classes == c)
            class_mean = features[class_mask].mean(dim=0)
            class_means.append(class_mean)
        class_means = torch.stack(class_means)  # (C, d)
        
        between_spread = 0
        for i in range(C):
            for j in range(i+1, C):
                between_spread += F.mse_loss(class_means[i], class_means[j])
        f2 = between_spread / (C * (C-1) / 2)
        
        # Feature 3: ETF alignment (simplified)
        f3 = 0.5  # Placeholder, actual implementation uses last layer weights
        
        # Feature 4: Decision boundary distance
        decision_dist = float('inf')
        for c in range(C):
            class_mask = (pred_classes == c)
            if class_mask.sum() > 0:
                class_mean = features[class_mask].mean(dim=0)
                dist = class_mean.norm()
                decision_dist = min(decision_dist, dist.item())
        f4 = decision_dist
        
        # Feature 5: Feature space curvature (simplified)
        f5 = features.norm(dim=-1).std().item()
        
        return [f1.item(), f2.item(), f3, f4, f5]
    
    def predict_best_tta(self, features, probs, entropy):
        """
        Predict best TTA method without actual adaptation
        """
        # Compute NC features
        nc_features = self.compute_neural_collapse_features(features, probs)
        
        # Build input vector
        X = [[entropy.mean().item()] + nc_features]
        
        # Predict for each method
        predictions = {}
        for method in self.tta_methods:
            predictions[method] = self.regressor.predict(X)[0]
        
        # Select best method
        best_method = max(predictions, key=lambda k: predictions[k])
        
        return best_method, predictions
    
    def fit(self, train_data):
        """
        Fit the regression model on training data
        
        train_data: List of (nc_features, entropy, method, accuracy) tuples
        """
        X = []
        y = []
        
        for nc_feat, ent, method, acc in train_data:
            X.append([ent] + nc_feat)
            y.append(acc)
        
        self.regressor.fit(X, y)
        self.fitted = True

训练数据收集

MORPHEUS需要以下训练数据:

# 数据收集伪代码
def collect_training_data(source_model, corruption_benchmarks):
    """
    Collect training data for MORPHEUS regression model
    """
    train_data = []
    
    for benchmark in corruption_benchmarks:
        for corruption_type in benchmark.types:
            # Extract features and NC geometry
            features, probs, entropy = morpheus.extract_features(
                source_model, benchmark[corruption_type]
            )
            nc_features = morpheus.compute_neural_collapse_features(
                features, probs
            )
            
            # Evaluate all TTA methods
            for method in tta_methods:
                acc = evaluate_tta(source_model, method, benchmark[corruption_type])
                train_data.append((nc_features, entropy.mean(), method, acc))
    
    return train_data

与其他元TTA方法的对比

维度MORPHEUSMetaTTAAdaTTA
预测机制NC几何特征元学习在线学习
是否实际适应部分
计算开销O(n)O(k×n)O(n)
准确性SOTA中等较好
可解释性✅ NC理论❌ 黑盒❌ 黑盒

局限性与未来方向

局限性

  1. 需要训练数据:需要收集多样域偏移下的TTA表现数据
  2. 回归模型限制:简单随机森林可能无法捕捉复杂关系
  3. NC特征近似:部分NC特征需要访问最后一层权重
  4. 方法覆盖:无法预测新提出的TTA方法

未来方向

  1. 自适应NC特征:根据网络架构自适应提取NC特征
  2. 深度回归模型:使用神经网络替代随机森林
  3. 在线更新:在部署中持续更新回归模型
  4. 跨架构迁移:研究NC特征的跨架构迁移性

总结

MORPHEUS的核心贡献是提出了首个元测试时适应框架,通过神经崩溃几何特征实现无需实际执行适应即可选择最佳TTA方法。

关键创新

创新点描述
NC几何特征5个几何特征捕捉域偏移的本质
无需适应单次前向传播即可决策
防止灾难失败通过预测避免性能退化
可解释性基于NC理论的明确解释

性能提升

  • 平均准确率:Prior SOTA 64.2% → 65.2%
  • 标准差:0.9 → 0.7 (-22%)
  • 灾难失败:12次 → 3次 (-75%)

参考

Footnotes

  1. Papyan V, Han X Y, Donoho D L. Prevalence of Neural Collapse during the terminal phase of deep learning training. PNAS, 2020.