ViT-5:面向2020年代中期的现代Vision Transformer
概述
ViT-5(Vision Transformers for The Mid-2020s)是对原始Vision Transformer架构进行系统性现代化升级的研究工作。该工作并非提出全新的架构范式,而是基于过去五年的研究积累,对ViT的各个组件进行组件级(component-wise)精细化改进。
核心论文:arXiv:2602.080711
代码实现:GitHub - ViT-5
1. 背景与动机
Vision Transformer的发展历程
Vision Transformer(ViT)自2020年提出以来,已成为计算机视觉领域的基础架构。然而,原始ViT的设计在许多方面已经落后于最新的研究成果:
时间线:
2020 ── ViT ── 原始架构,Data-efficient Image Transformer
2021 ── DeiT ── 知识蒸馏,数据效率提升
2022 ── DeiT-II/III ── 训练策略优化
2023 ── DeiT-IV ── 持续改进
2024 ── 各种变体 ── Swin、ConvNeXt等
2026 ── ViT-5 ── 系统性现代化
现有改进的碎片化问题
过去几年,研究社区提出了大量改进ViT的方法:
| 改进方向 | 方法 | 发表年份 |
|---|---|---|
| 归一化 | LayerNorm → RMSNorm/GN | 2021-2023 |
| 激活函数 | GELU → SiLU/Swish | 2021-2022 |
| 位置编码 | 可学习 → RoPE/ALiBi | 2022-2023 |
| 门控机制 | 无 → 门控FFN | 2022-2023 |
| 令牌设计 | 单一CLS → 可学习令牌 | 2021-2023 |
问题:这些改进分散在不同论文中,缺乏统一的整合和系统评估。
ViT-5的目标
“While preserving the canonical Attention-FFN structure, we conduct a component-wise refinement”
ViT-5的核心理念:
- 保持ViT的简洁性——不引入复杂机制(如Swin的窗口移位)
- 系统整合——将过去五年的改进统一整合
- 组件级优化——对每个组件进行精细化改进
2. 核心改进:组件级现代化
2.1 归一化层现代化
原始ViT的问题
// 原始ViT使用标准LayerNorm
class OriginalViTNorm {
// LayerNorm实现
// 问题:计算开销较大,包含均值和方差计算
};ViT-5的改进
import torch
import torch.nn as nn
class ViT5NormConfig:
"""ViT-5归一化配置"""
# 选项1:RMSNorm(更高效)
# 移除均值计算,仅计算RMS
def rms_norm(x, normalized_shape, weight=None, eps=1e-6):
"""
RMSNorm:Root Mean Square Layer Normalization
优势:
- 移除均值计算,降低计算复杂度
- 30-40%的归一化层加速
- 在视觉任务上效果相当
"""
rms = torch.sqrt(torch.mean(x.pow(2), dim=-1, keepdim=True) + eps)
x_norm = x / rms
if weight is not None:
x_norm = weight * x_norm
return x_norm
# 选项2:GroupNorm(对小batch更稳定)
# 在通道维度上分组归一化
def group_norm(x, num_groups, num_channels, eps=1e-6):
"""
GroupNorm:分组归一化
优势:
- 对batch size不敏感
- 训练稳定性更好
- 适合视觉任务
"""
assert num_channels % num_groups == 0
x = x.view(x.size(0), num_groups, num_channels // num_groups, -1)
mean = x.mean(dim=[2, 3], keepdim=True)
var = x.var(dim=[2, 3], unbiased=False, keepdim=True)
x = (x - mean) / torch.sqrt(var + eps)
x = x.view(x.size(0), num_channels, *x.shape[2:])
return x2.2 激活函数现代化
原始ViT:GELU
# 原始ViT使用GELU
class OriginalActivation:
gelu = nn.GELU()
# GELU: x * Phi(x),其中Phi是标准正态CDF
# 计算开销较大,涉及erf函数ViT-5:更现代的激活
class ViT5Activation(nn.Module):
"""
ViT-5使用的现代激活函数
基于过去五年的实证研究选择
"""
def __init__(self, activation_type='silu'):
super().__init__()
self.activation_type = activation_type
if activation_type == 'silu':
# SiLU/Swish: x * sigmoid(x)
# 优势:自门控特性,训练更稳定
self.act = nn.SiLU()
elif activation_type == 'gelu_tanh':
# GELU-Tanh近似(用于加速)
self.act = nn.GELU(approximate='tanh')
else:
self.act = nn.GELU()
def forward(self, x):
return self.act(x)
# 门控激活机制
class GatedActivation(nn.Module):
"""
门控激活:f(x) = gate(x) * main(x)
优势:
- 信息流更可控
- 减少表示崩溃
- 更好的梯度流
"""
def __init__(self, d_model, d_ff, activation='silu'):
super().__init__()
self.gate_proj = nn.Linear(d_model, d_ff)
self.value_proj = nn.Linear(d_model, d_ff)
self.out_proj = nn.Linear(d_ff, d_model)
self.act = nn.SiLU() if activation == 'silu' else nn.GELU()
def forward(self, x):
gate = self.act(self.gate_proj(x))
value = self.value_proj(x)
return self.out_proj(gate * value)2.3 位置编码现代化
原始ViT的可学习位置编码
# 原始ViT使用可学习的位置编码
class OriginalViTPositionalEncoding:
def __init__(self, seq_len, d_model):
self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, d_model))
# 问题:位置编码长度固定,难以泛化到训练时未见过的长度ViT-5的改进
class ViT5PositionalEncoding(nn.Module):
"""
ViT-5位置编码:结合多种现代技术
"""
def __init__(self, d_model, max_len=2048, rope_type='rotary'):
super().__init__()
self.rope_type = rope_type
if rope_type == 'rotary':
# RoPE:旋转位置编码
# 优势:无需额外参数,可外推到更长序列
self.rope = RotaryPositionalEmbedding(d_model, max_len)
elif rope_type == 'alibi':
# ALiBi:注意力链式偏置
# 优势:无需位置编码,自然支持更长序列
self.alibi = ALiBiAttentionBias(d_model)
elif rope_type == 'learned':
# 可学习 + 插值策略
self.pos_embed = nn.Parameter(torch.zeros(1, max_len, d_model))
def forward(self, x, seq_len=None):
if self.rope_type == 'rotary':
return self.rope(x)
elif self.rope_type == 'alibi':
return self.alibi.get_bias(seq_len)
else:
return self.pos_embed[:, :seq_len]
class RotaryPositionalEmbedding(nn.Module):
"""
旋转位置编码(RoPE)
核心思想:将位置信息编码为旋转矩阵
"""
def __init__(self, dim, max_len=2048):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
# 预计算旋转矩阵
t = torch.arange(max_len).type_as(self.inv_freq)
freqs = torch.einsum('n,i->ni', t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer('cos_cached', emb.cos())
self.register_buffer('sin_cached', emb.sin())
@torch.jit.script
def rotate_half(x):
"""将x分成两半并旋转"""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def forward(self, q, k):
# 应用旋转
q_embed = (q * self.cos_cached[:q.shape[1]]) + \
(self.rotate_half(q) * self.sin_cached[:q.shape[1]])
k_embed = (k * self.cos_cached[:k.shape[1]]) + \
(self.rotate_half(k) * self.sin_cached[:k.shape[1]])
return q_embed, k_embed2.4 门控机制
class ViT5AttentionGate(nn.Module):
"""
注意力门控:增强信息流控制
"""
def __init__(self, d_model, n_heads):
super().__init__()
self.gate = nn.Sequential(
nn.Linear(d_model, d_model),
nn.Sigmoid() # 门控值在(0, 1)
)
def forward(self, attn_output, hidden_state):
"""
门控后的注意力输出
"""
gate_values = self.gate(hidden_state)
return gate_values * attn_output
class ViT5FFNGate(nn.Module):
"""
FFN门控:减少表示崩溃
"""
def __init__(self, d_model, d_ff):
super().__init__()
self.gate = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.SiLU(),
nn.Linear(d_ff, d_model),
nn.Sigmoid()
)
def forward(self, ffn_output, hidden_state):
return self.gate(hidden_state) * ffn_output2.5 可学习令牌设计
class ViT5LearnableTokens(nn.Module):
"""
ViT-5的可学习令牌设计
替代单一CLS token的改进
"""
def __init__(self, n_tokens, d_model):
super().__init__()
# 多个可学习令牌
self.cls_tokens = nn.Parameter(torch.randn(1, n_tokens, d_model))
# 可选的蒸馏令牌
self.dist_token = nn.Parameter(torch.randn(1, 1, d_model))
def forward(self, x):
"""
Args:
x: [batch, seq_len, d_model] - patch嵌入
Returns:
tokens: [batch, n_tokens + 1 + dist, d_model]
"""
batch_size = x.shape[0]
# 复制cls tokens到batch
cls_tokens = self.cls_tokens.expand(batch_size, -1, -1)
# 拼接
x = torch.cat([
cls_tokens, # [batch, n_tokens, d_model]
x, # [batch, seq_len, d_model]
], dim=1)
return x
def aggregate(self, x, n_tokens):
"""
聚合多个令牌的输出
"""
cls_out = x[:, :n_tokens] # [batch, n_tokens, d_model]
# 可使用平均、注意力加权等方式聚合
return cls_out.mean(dim=1) # 简单平均3. 完整ViT-5架构
import torch
import torch.nn as nn
from functools import partial
class ViT5Block(nn.Module):
"""ViT-5 Transformer块"""
def __init__(
self,
d_model: int,
n_heads: int,
d_ff: int = None,
dropout: float = 0.0,
drop_path: float = 0.0,
activation: str = 'silu',
norm_type: str = 'rmsnorm',
use_gate: bool = True,
):
super().__init__()
d_ff = d_ff or 4 * d_model
# 归一化
if norm_type == 'rmsnorm':
self.norm1 = RMSNorm(d_model)
self.norm2 = RMSNorm(d_model)
elif norm_type == 'layernorm':
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
else:
self.norm1 = nn.GroupNorm(d_model // 32, d_model)
self.norm2 = nn.GroupNorm(d_model // 32, d_model)
# 注意力
self.attn = nn.MultiheadAttention(
d_model, n_heads, dropout=dropout, batch_first=True
)
# 门控
if use_gate:
self.attn_gate = ViT5AttentionGate(d_model, n_heads)
# FFN
self.ffn = GatedFFN(d_model, d_ff, activation)
if use_gate:
self.ffn_gate = ViT5FFNGate(d_model, d_ff)
# DropPath
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
# 注意力残差连接
h = self.norm1(x)
attn_out, _ = self.attn(h, h, h)
if hasattr(self, 'attn_gate'):
attn_out = self.attn_gate(attn_out, h)
x = x + self.drop_path(attn_out)
# FFN残差连接
h = self.norm2(x)
ffn_out = self.ffn(h)
if hasattr(self, 'ffn_gate'):
ffn_out = self.ffn_gate(ffn_out, h)
x = x + self.drop_path(ffn_out)
return x
class VisionTransformer5(nn.Module):
"""完整ViT-5模型"""
def __init__(
self,
img_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
n_classes: int = 1000,
d_model: int = 768,
n_heads: int = 12,
n_layers: int = 12,
d_ff: int = None,
dropout: float = 0.0,
drop_path_rate: float = 0.1,
norm_type: str = 'rmsnorm',
activation: str = 'silu',
use_gate: bool = True,
n_learnable_tokens: int = 1,
use_rope: bool = True,
):
super().__init__()
# Patch嵌入
self.patch_embed = nn.Conv2d(
in_channels, d_model,
kernel_size=patch_size,
stride=patch_size
)
# 位置编码
n_patches = (img_size // patch_size) ** 2
if use_rope:
self.pos_encoding = None # 使用RoPE
else:
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches, d_model))
# 可学习令牌
self.cls_tokens = ViT5LearnableTokens(n_learnable_tokens, d_model)
# Transformer块
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
self.blocks = nn.ModuleList([
ViT5Block(
d_model, n_heads, d_ff, dropout,
drop_path=dpr[i],
norm_type=norm_type,
activation=activation,
use_gate=use_gate,
)
for i in range(n_layers)
])
# 输出头
self.norm = RMSNorm(d_model) if norm_type == 'rmsnorm' else nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, n_classes)
# 初始化
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, x):
# Patch嵌入
x = self.patch_embed(x).flatten(2).transpose(1, 2) # [B, N, D]
# 添加位置编码
if self.pos_embed is not None:
x = x + self.pos_embed
# 添加可学习令牌
x = self.cls_tokens(x)
# Transformer块
for block in self.blocks:
x = block(x)
# 输出
x = self.norm(x)
cls_out = self.cls_tokens.aggregate(x, n_tokens=self.cls_tokens.cls_tokens.shape[1])
return self.head(cls_out)4. 训练配置
4.1 优化器与学习率
class ViT5TrainingConfig:
"""
ViT-5训练配置
基于大规模视觉模型训练的最佳实践
"""
# 优化器:FusedLAMB
optimizer = {
'type': 'FusedLAMB',
'lr': 1e-3,
'weight_decay': 0.05,
'beta1': 0.9,
'beta2': 0.999,
'eps': 1e-8,
}
# 学习率调度
scheduler = {
'type': 'CosineAnnealingLR',
'T_max': 300,
'eta_min': 1e-6,
}
# 数据增强
augmentation = {
'mixup_alpha': 0.8,
'cutmix_alpha': 1.0,
'color_jitter': 0.3,
'three_augment': True,
'repeated_aug': True, # 重复增强
}
# 正则化
regularization = {
'drop_path': 0.05, # 小模型
# 0.35 for large model
'label_smoothing': 0.0, # 预训练
'label_smoothing': 0.1, # 微调
}
# 其他
warmup_epochs = 5
epochs = 300
batch_size = 4096 # 大batch训练4.2 训练策略对比
| 配置 | 原始ViT | DeiT | ViT-5 |
|---|---|---|---|
| 优化器 | AdamW | AdamW | FusedLAMB |
| Batch Size | 4096 | 1024 | 4096 |
| 权重衰减 | 0.1 | 0.05 | 0.05 |
| 标签平滑 | 0.0 | 0.1 | 0.0/0.1 |
| DropPath | 0.0 | 0.0-0.1 | 0.05-0.35 |
5. 实验结果
5.1 模型变体
| 模型 | 分辨率 | 参数量 | ImageNet-1K Top-1 |
|---|---|---|---|
| ViT-5-S | 224 | 22M | 82.2% |
| ViT-5-B | 224 | 87M | 84.2% |
| ViT-5-B | 384 | 87M | 85.4% |
| ViT-5-L | 224 | 304M | 84.9% |
| ViT-5-L | 384 | 304M | 86.0% |
5.2 与Swin Transformer对比
| 模型 | Swin-B | ViT-5-B | Swin-L | ViT-5-L |
|---|---|---|---|---|
| 参数量 | 88M | 87M | 196M | 304M |
| ImageNet | 83.8% | 84.2% | 85.2% | 85.5% |
| 推理速度 | 基准 | 1.1x | 基准 | 0.9x |
5.3 下游任务
| 任务 | Swin-B | ViT-5-B |
|---|---|---|
| COCO Detection | 48.5 mAP | 49.1 mAP |
| ADE20K Segmentation | 48.0 mIoU | 48.5 mIoU |
6. 与ConvNeXt的对比
架构哲学差异
| 方面 | ViT-5 | ConvNeXt |
|---|---|---|
| 核心机制 | 全局自注意力 | 大核卷积 |
| 设计理念 | 保持ViT简洁性 | 将CNN现代化 |
| 窗口机制 | 无(全局注意力) | 无(纯卷积) |
| 复杂性 | 低 | 中 |
| 可扩展性 | 高 | 高 |
适用场景
def choose_architecture():
"""
ViT-5 vs ConvNeXt 选择指南
"""
# 选择ViT-5的场景:
# - 需要全局感受野
# - 任务涉及长距离依赖
# - 喜欢简洁架构
# - 需要与ViT预训练模型兼容
vit5_scenarios = [
"图像分类(高分辨率)",
"语义分割",
"多尺度目标检测",
"需要预训练ViT迁移"
]
# 选择ConvNeXt的场景:
# - 硬件对卷积友好
# - 需要高效率推理
# - 密集预测任务
convnext_scenarios = [
"实时应用",
"边缘设备部署",
"工业检测",
"视频理解"
]
return vit5_scenarios, convnext_scenarios7. 总结
ViT-5的核心启示
- 渐进式改进的有效性:不需要全新的架构范式,系统性的组件升级同样有效
- 保持简洁性:避免引入不必要的复杂性(如Swin的窗口机制)
- 研究积累的价值:将分散的改进整合可以带来显著提升
- ViT的持久生命力:经过现代化升级,ViT仍是最先进的架构之一
未来方向
- 更高效的注意力机制
- 自适应计算分配
- 多模态融合
参考资料
相关专题:ViT详解 | Swin Transformer | ConvNeXt
Footnotes
-
Wang et al. (2026). ViT-5: Vision Transformers for The Mid-2020s. arXiv:2602.08071 ↩