StableTTA:免训练测试时适应
概述
StableTTA(arXiv:2604.04552)提出了一种无需训练的测试时适应方法,通过发现并解决聚合策略冲突问题,在ImageNet-1K上实现了显著的准确率提升。其核心贡献是发现了不同聚合策略之间的内在冲突,并提出**NSS(非显著抑制)**来解决这一问题。
核心发现
StableTTA的关键洞察是:当使用不同数据增强的logits进行聚合时,Hard Voting、Soft Voting和Logit Averaging三种策略之间存在冲突。这种冲突导致预测不一致,限制了集成方法的性能上限。
主要贡献
- 理论分析:基于Hölder条件分析了聚合冲突的来源
- NSS方法:提出非显著抑制操作,降低logit方差
- 稳定增强策略:改进Mixup和CutMix为确定性策略
- SOTA性能:在ImageNet-1K上使33个模型超过95%准确率,多个超过96%
问题背景:聚合策略冲突
三种聚合策略
在测试时适应中,常用的聚合策略包括:
| 聚合策略 | 公式 | 特点 |
|---|---|---|
| Hard Voting | 对logit尺度不敏感 | |
| Soft Voting | $\hat{y}{soft} = \arg\max_k \frac{1}{N}\sum{i=1}^{N} p_{\theta^{(i)}}(y=k | x)$ |
| Logit Averaging | 对logit尺度敏感 |
冲突来源
softmax和indicator函数都是非线性函数,导致不同聚合策略产生不一致的结果。
对于神经网络的前向传播,有Hölder条件:
对于数据增强后的logits:
关键发现:当方差 降低时,预测冲突概率降低。
核心方法
1. 非显著抑制(NSS)
NSS是StableTTA的核心操作,对logits进行方差降低处理:
其中:
- 是Top-K索引的指示向量
- 表示逐元素乘法
- 是保留的候选类别数
几何解释:将非Top-K位置的logits设置为当前最小值,实现”抑制”效果。
NSS的梯度性质
NSS的一个关键性质是使每个logit的梯度范数恒为1:
方差降低证明
基于高斯Poincaré不等式:
因此NSS有效降低方差,从而减少聚合冲突。
2. 稳定数据增强策略
StableTTA对传统的数据增强策略进行了关键修改:
| 方法 | 传统策略 | StableTTA策略 |
|---|---|---|
| Mixup | 与随机图像加权组合 | 与固定图像组合(从数据集中随机采样一次) |
| CutMix | 覆盖区域从Beta分布采样 | 固定窗口大小为图像的1/4 |
关键改进:从随机采样改为固定采样,减少增强的多样性,从而降低方差。
3. 最终推理流程
其中
算法流程
// StableTTA 核心伪代码
#include <bits/stdc++.h>
using namespace std;
vector<float> nss(const vector<float>& z, int K) {
int C = z.size();
// 找到Top-K索引
vector<pair<float, int>> indexed(C);
for (int i = 0; i < C; i++) indexed[i] = {z[i], i};
sort(indexed.begin(), indexed.end(), greater<pair<float, int>>());
float min_val = *min_element(z.begin(), z.end());
vector<float> result(C);
unordered_set<int> top_k;
for (int i = 0; i < K; i++) top_k.insert(indexed[i].second);
for (int i = 0; i < C; i++) {
result[i] = top_k.count(i) ? z[i] : min_val;
}
return result;
}
vector<float> stable_tta(model, const vector<float>& x, int N=32, int K=10) {
vector<vector<float>> logits_list;
for (int i = 0; i < N; i++) {
// 应用稳定的数据增强
auto x_aug = stable_augmentation(x, choice=['mixup', 'cutmix']);
// 前向传播
auto z = model.forward(x_aug);
logits_list.push_back(z);
}
// NSS处理
vector<vector<float>> processed;
for (auto& z : logits_list) {
processed.push_back(nss(z, K));
}
// 平均
vector<float> final_logit(C, 0.0f);
for (auto& p : processed) {
for (int i = 0; i < C; i++) final_logit[i] += p[i];
}
for (auto& v : final_logit) v /= N;
return final_logit;
}实验结果
ImageNet-1K主要结果
| 模型 | 基线Acc@1 | StableTTA Acc@1 | 提升 |
|---|---|---|---|
| AlexNet | 56.52% | 89.33% | +32.82% |
| MobileNet-V3-Small | 67.67% | 92.83% | +25.16% |
| ResNet50 | 76.13% | 94.52% | +18.39% |
| EfficientNet-B4 | 83.38% | 95.39% | +12.01% |
| ViT-B-16 | 81.07% | 95.85% | +14.78% |
| ConvNeXt-Base | 84.06% | 95.96% | +11.90% |
| Swin-Base | 84.53% | 96.04% | +11.51% |
关键发现
轻量级模型超越大型模型
StableTTA + MobileNetV3 超越了 ViT-B-16:
| 指标 | MobileNetV3 + StableTTA | ViT-B-16 |
|---|---|---|
| Top-1准确率 | +11.75% | - |
| 参数量 | -97.1% | 基准 |
| GFLOPs | -89.1% | 基准 |
超参数鲁棒性
推荐设置:N=32, K=10
| N | K=5 | K=10 | K=15 |
|---|---|---|---|
| 16 | 94.2% | 94.8% | 94.5% |
| 32 | 95.1% | 95.4% | 95.2% |
| 64 | 95.3% | 95.5% | 95.3% |
与其他免训练TTA方法的对比
| 方法 | 核心机制 | 准确率提升 | 计算开销 |
|---|---|---|---|
| StableTTA | NSS + 稳定增强 | +10-33% | N×前向 |
| TPT | 置信度选择 + 熵最小化 | +2-8% | N×前向 |
| Tent | 源-free BN统计 | +1-5% | 轻微额外计算 |
代码实现要点
PyTorch实现
import torch
import torch.nn.functional as F
def nss(logits: torch.Tensor, K: int) -> torch.Tensor:
"""
Non-Significant Suppression operation
logits: (batch_size, num_classes)
K: number of top candidates to keep
"""
batch_size, C = logits.shape
min_val = logits.min(dim=-1, keepdim=True).values
# Get top-K indices
_, top_k_idx = logits.topk(K, dim=-1)
# Create mask
mask = torch.zeros_like(logits, dtype=torch.bool)
mask.scatter_(1, top_k_idx, True)
# NSS operation
result = torch.where(mask, logits, min_val)
return result
def stable_tta(model, images, N=32, K=10, device='cuda'):
"""
StableTTA inference
"""
model.eval()
all_logits = []
with torch.no_grad():
for _ in range(N):
# Stable augmentation
aug_images = stable_mixup(images) # or stable_cutmix
logits = model(aug_images)
processed = nss(logits, K)
all_logits.append(processed)
# Average
avg_logits = torch.stack(all_logits).mean(dim=0)
return avg_logits
# Stable augmentation functions
def stable_mixup(images, alpha=0.4):
"""Mix with a fixed reference image"""
batch_size = images.shape[0]
# Sample a fixed mixing image (re-sampled per batch)
lam = torch.from_numpy(np.random.dirichlet([alpha, alpha]))
ref_idx = torch.randint(0, batch_size, (1,)).item()
mixed = lam[0] * images + lam[1] * images[ref_idx:ref_idx+1]
return mixed推荐配置
| 参数 | 推荐值 | 说明 |
|---|---|---|
| N | 32 | 增强次数 |
| K | 10 | Top-K候选数 |
| Mixup λ | Dirichlet(0.4, 0.4) | 混合比例 |
| CutMix比例 | 0.25 | 固定窗口大小 |
总结
StableTTA的核心贡献是发现了TTA中聚合策略冲突的问题,并通过NSS操作和稳定增强策略有效解决了这一问题。
关键要点:
- NSS降低方差:通过将非Top-K位置的logits设置为最小值
- 稳定增强减少多样性:从随机采样改为固定采样
- 轻量级超越重型:MobileNetV3 + StableTTA > ViT-B-16
- 无需训练:完全在测试时执行,无需梯度更新
代码仓库:https://github.com/LizhengMathAi/StableTTA