StableTTA:免训练测试时适应

概述

StableTTA(arXiv:2604.04552)提出了一种无需训练的测试时适应方法,通过发现并解决聚合策略冲突问题,在ImageNet-1K上实现了显著的准确率提升。其核心贡献是发现了不同聚合策略之间的内在冲突,并提出**NSS(非显著抑制)**来解决这一问题。

核心发现

StableTTA的关键洞察是:当使用不同数据增强的logits进行聚合时,Hard Voting、Soft Voting和Logit Averaging三种策略之间存在冲突。这种冲突导致预测不一致,限制了集成方法的性能上限。

主要贡献

  1. 理论分析:基于Hölder条件分析了聚合冲突的来源
  2. NSS方法:提出非显著抑制操作,降低logit方差
  3. 稳定增强策略:改进Mixup和CutMix为确定性策略
  4. SOTA性能:在ImageNet-1K上使33个模型超过95%准确率,多个超过96%

问题背景:聚合策略冲突

三种聚合策略

在测试时适应中,常用的聚合策略包括:

聚合策略公式特点
Hard Voting对logit尺度不敏感
Soft Voting$\hat{y}{soft} = \arg\max_k \frac{1}{N}\sum{i=1}^{N} p_{\theta^{(i)}}(y=kx)$
Logit Averaging对logit尺度敏感

冲突来源

softmax和indicator函数都是非线性函数,导致不同聚合策略产生不一致的结果。

对于神经网络的前向传播,有Hölder条件:

对于数据增强后的logits:

关键发现:当方差 降低时,预测冲突概率降低。


核心方法

1. 非显著抑制(NSS)

NSS是StableTTA的核心操作,对logits进行方差降低处理:

其中:

  • 是Top-K索引的指示向量
  • 表示逐元素乘法
  • 是保留的候选类别数

几何解释:将非Top-K位置的logits设置为当前最小值,实现”抑制”效果。

NSS的梯度性质

NSS的一个关键性质是使每个logit的梯度范数恒为1:

方差降低证明

基于高斯Poincaré不等式

因此NSS有效降低方差,从而减少聚合冲突。

2. 稳定数据增强策略

StableTTA对传统的数据增强策略进行了关键修改:

方法传统策略StableTTA策略
Mixup与随机图像加权组合与固定图像组合(从数据集中随机采样一次)
CutMix覆盖区域从Beta分布采样固定窗口大小为图像的1/4

关键改进:从随机采样改为固定采样,减少增强的多样性,从而降低方差。

3. 最终推理流程

其中


算法流程

// StableTTA 核心伪代码
#include <bits/stdc++.h>
using namespace std;
 
vector<float> nss(const vector<float>& z, int K) {
    int C = z.size();
    // 找到Top-K索引
    vector<pair<float, int>> indexed(C);
    for (int i = 0; i < C; i++) indexed[i] = {z[i], i};
    sort(indexed.begin(), indexed.end(), greater<pair<float, int>>());
    
    float min_val = *min_element(z.begin(), z.end());
    vector<float> result(C);
    unordered_set<int> top_k;
    
    for (int i = 0; i < K; i++) top_k.insert(indexed[i].second);
    for (int i = 0; i < C; i++) {
        result[i] = top_k.count(i) ? z[i] : min_val;
    }
    return result;
}
 
vector<float> stable_tta(model, const vector<float>& x, int N=32, int K=10) {
    vector<vector<float>> logits_list;
    
    for (int i = 0; i < N; i++) {
        // 应用稳定的数据增强
        auto x_aug = stable_augmentation(x, choice=['mixup', 'cutmix']);
        // 前向传播
        auto z = model.forward(x_aug);
        logits_list.push_back(z);
    }
    
    // NSS处理
    vector<vector<float>> processed;
    for (auto& z : logits_list) {
        processed.push_back(nss(z, K));
    }
    
    // 平均
    vector<float> final_logit(C, 0.0f);
    for (auto& p : processed) {
        for (int i = 0; i < C; i++) final_logit[i] += p[i];
    }
    for (auto& v : final_logit) v /= N;
    
    return final_logit;
}

实验结果

ImageNet-1K主要结果

模型基线Acc@1StableTTA Acc@1提升
AlexNet56.52%89.33%+32.82%
MobileNet-V3-Small67.67%92.83%+25.16%
ResNet5076.13%94.52%+18.39%
EfficientNet-B483.38%95.39%+12.01%
ViT-B-1681.07%95.85%+14.78%
ConvNeXt-Base84.06%95.96%+11.90%
Swin-Base84.53%96.04%+11.51%

关键发现

轻量级模型超越大型模型

StableTTA + MobileNetV3 超越了 ViT-B-16:

指标MobileNetV3 + StableTTAViT-B-16
Top-1准确率+11.75%-
参数量-97.1%基准
GFLOPs-89.1%基准

超参数鲁棒性

推荐设置:N=32, K=10

NK=5K=10K=15
1694.2%94.8%94.5%
3295.1%95.4%95.2%
6495.3%95.5%95.3%

与其他免训练TTA方法的对比

方法核心机制准确率提升计算开销
StableTTANSS + 稳定增强+10-33%N×前向
TPT置信度选择 + 熵最小化+2-8%N×前向
Tent源-free BN统计+1-5%轻微额外计算

代码实现要点

PyTorch实现

import torch
import torch.nn.functional as F
 
def nss(logits: torch.Tensor, K: int) -> torch.Tensor:
    """
    Non-Significant Suppression operation
    logits: (batch_size, num_classes)
    K: number of top candidates to keep
    """
    batch_size, C = logits.shape
    min_val = logits.min(dim=-1, keepdim=True).values
    
    # Get top-K indices
    _, top_k_idx = logits.topk(K, dim=-1)
    
    # Create mask
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask.scatter_(1, top_k_idx, True)
    
    # NSS operation
    result = torch.where(mask, logits, min_val)
    return result
 
def stable_tta(model, images, N=32, K=10, device='cuda'):
    """
    StableTTA inference
    """
    model.eval()
    all_logits = []
    
    with torch.no_grad():
        for _ in range(N):
            # Stable augmentation
            aug_images = stable_mixup(images)  # or stable_cutmix
            logits = model(aug_images)
            processed = nss(logits, K)
            all_logits.append(processed)
    
    # Average
    avg_logits = torch.stack(all_logits).mean(dim=0)
    return avg_logits
 
# Stable augmentation functions
def stable_mixup(images, alpha=0.4):
    """Mix with a fixed reference image"""
    batch_size = images.shape[0]
    # Sample a fixed mixing image (re-sampled per batch)
    lam = torch.from_numpy(np.random.dirichlet([alpha, alpha]))
    ref_idx = torch.randint(0, batch_size, (1,)).item()
    mixed = lam[0] * images + lam[1] * images[ref_idx:ref_idx+1]
    return mixed

推荐配置

参数推荐值说明
N32增强次数
K10Top-K候选数
Mixup λDirichlet(0.4, 0.4)混合比例
CutMix比例0.25固定窗口大小

总结

StableTTA的核心贡献是发现了TTA中聚合策略冲突的问题,并通过NSS操作和稳定增强策略有效解决了这一问题。

关键要点

  1. NSS降低方差:通过将非Top-K位置的logits设置为最小值
  2. 稳定增强减少多样性:从随机采样改为固定采样
  3. 轻量级超越重型:MobileNetV3 + StableTTA > ViT-B-16
  4. 无需训练:完全在测试时执行,无需梯度更新

代码仓库https://github.com/LizhengMathAi/StableTTA


参考